
Removed the per-mode **_step methods from the driver and replaced them with model_call and get_model_call_fn; removed all dataloaders from the driver

tags/v1.0.0alpha
YWMditto 2 years ago
parent commit 2924e2117f
29 changed files with 273 additions and 452 deletions
  1. +1 -1 fastNLP/core/callbacks/checkpoint_callback.py
  2. +17 -32 fastNLP/core/controllers/evaluator.py
  3. +60 -37 fastNLP/core/controllers/trainer.py
  4. +1 -1 fastNLP/core/controllers/utils/utils.py
  5. +31 -91 fastNLP/core/drivers/driver.py
  6. +7 -7 fastNLP/core/drivers/jittor_driver/jittor_driver.py
  7. +5 -5 fastNLP/core/drivers/jittor_driver/single_device.py
  8. +3 -3 fastNLP/core/drivers/paddle_driver/fleet.py
  9. +8 -8 fastNLP/core/drivers/paddle_driver/paddle_driver.py
  10. +8 -8 fastNLP/core/drivers/paddle_driver/single_device.py
  11. +7 -7 fastNLP/core/drivers/paddle_driver/utils.py
  12. +40 -61 fastNLP/core/drivers/torch_driver/ddp.py
  13. +26 -70 fastNLP/core/drivers/torch_driver/single_device.py
  14. +1 -18 fastNLP/core/drivers/torch_driver/torch_driver.py
  15. +6 -51 fastNLP/core/drivers/torch_driver/utils.py
  16. +5 -5 fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py
  17. +3 -3 fastNLP/core/log/logger.py
  18. +8 -8 tests/core/callbacks/test_checkpoint_callback_torch.py
  19. +1 -1 tests/core/callbacks/test_load_best_model_callback_torch.py
  20. +1 -1 tests/core/controllers/_test_distributed_launch_torch_1.py
  21. +1 -1 tests/core/controllers/_test_distributed_launch_torch_2.py
  22. +1 -1 tests/core/controllers/test_trainer_event_trigger.py
  23. +2 -2 tests/core/controllers/test_trainer_fleet.py
  24. +2 -2 tests/core/controllers/test_trainer_fleet_outside.py
  25. +7 -7 tests/core/controllers/test_trainer_paddle.py
  26. +8 -8 tests/core/controllers/test_trainer_w_evaluator_torch.py
  27. +5 -5 tests/core/controllers/test_trainer_wo_evaluator_torch.py
  28. +7 -7 tests/core/drivers/paddle_driver/test_single_device.py
  29. +1 -1 tests/helpers/models/torch_model.py
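
Taken together, the diffs below replace the driver's per-mode step methods (train_step/validate_step/test_step) with a single dispatch pair: get_model_call_fn resolves the step function once, and model_call invokes it on every batch. A minimal runnable sketch of the new calling convention (SketchDriver and ToyModel are illustrative stand-ins, not fastNLP classes; the real drivers additionally route dict batches through auto_param_call):

from typing import Callable, Dict, Optional, Tuple

class SketchDriver:
    """Illustrative stand-in for fastNLP's Driver after this commit."""

    def __init__(self, model):
        self.model = model

    def get_model_call_fn(self, fn: str) -> Tuple[Callable, Optional[Callable]]:
        # Resolve the step function once, up front: prefer model.<fn>, and fall
        # back to model.forward for the two well-known step names.
        if hasattr(self.model, fn):
            return getattr(self.model, fn), None
        if fn in {"train_step", "evaluate_step"}:
            return self.model, self.model.forward
        raise RuntimeError(f"There is no `{fn}` method in your model.")

    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
        # One generic entry point instead of train_step/validate_step/test_step.
        return fn(batch)

class ToyModel:
    def train_step(self, batch) -> Dict:
        return {"loss": float(sum(batch))}

driver = SketchDriver(ToyModel())
step_fn, signature_fn = driver.get_model_call_fn("train_step")  # resolved once by the Trainer
print(driver.model_call([1.0, 2.0], step_fn, signature_fn))     # {'loss': 3.0}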

+1 -1 fastNLP/core/callbacks/checkpoint_callback.py

@@ -95,7 +95,7 @@ class CheckpointCallback(HasMonitorCallback):
         if self.save_topk is not None:
             super().on_after_trainer_initialized(trainer, driver)
         if self.save_topk is not None and trainer.evaluator is None:
-            logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.")
+            logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.")

     def on_validate_end(self, trainer, results):
         if len(results) == 0:


+17 -32 fastNLP/core/controllers/evaluator.py

@@ -39,7 +39,7 @@ class Evaluator:
             driver: Union[str, Driver] = 'single',
             device: Optional[Union[int, List[int], str]] = None,
             batch_step_fn: Optional[callable] = None,
-            mode: Optional[Union[str, callable]] = 'validate',  # first try evaluate_step; if not found, forward; may be callable
+            evaluate_fn: Optional[str] = None,  # first try evaluate_step; if not found, forward; may be callable
             input_mapping: Optional[Union[Callable, Dict]] = None,
             output_mapping: Optional[Union[Callable, Dict]] = None,
             model_wo_auto_param_call: bool = False,
@@ -58,14 +58,13 @@ class Evaluator:
        :param batch_step_fn: a callable that accepts (evaluator, batch) as parameters, where evaluator is the Evaluator object and batch is
         the object returned by the DataLoader. For an example of a batch_step_fn, see the batch_step_fn function in
         fastNLP.core.controller.loops.evaluate_batch_loop.
-       :param mode: one of ["validate", "test"]. When "validate", first look for a validate_step method on the model, then fall back
-        to a test_step method, and if neither is found use the model's forward computation. When "test", first look for a test_step method,
-        then "validate_step", and if neither is found use the model's forward computation.
+       :param evaluate_fn: controls which function the `Evaluator` calls for the forward pass during evaluation, e.g. `model.evaluate_step` versus `model.forward`;
+        defaults to None, in which case we use `evaluate_step` as the forward function, falling back to `model.forward` if the model does not define that method;
        :param input_mapping: the content returned by the dataloader is processed by input_mapping before being fed to the model and the metrics
        :param output_mapping: the model's output is processed by output_mapping before being fed to the metrics.
        :param model_wo_auto_param_call: whether to disable calling our auto_param_call during training to automatically match the batch to the forward function's parameters;
         if this value is True, and the batch is a dict, we extract the objects the forward function needs from the batch and pass them to forward; if this value
-        is False, we pass the batch straight through to the forward function. Note the same logic also applies to `train_step`, `validate_step` and `test_step`;
+        is False, we pass the batch straight through to the forward function. Note the same logic also applies to `train_step`, `evaluate_step` and `test_step`;
        :param fp16: whether to use fp16.
        :param verbose: whether to print the evaluate results.
        :param kwargs:
@@ -87,9 +86,11 @@ class Evaluator:

         self.model = model
         self.metrics = metrics

         self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs)

+        if dataloaders is None:
+            raise ValueError("Parameter `dataloaders` can not be None.")
+        self.dataloaders = dataloaders
         self.device = device
         self.verbose = verbose

@@ -97,21 +98,12 @@ class Evaluator:
             _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
         self.batch_step_fn = batch_step_fn

-        self.mode = mode
-        assert mode in {'validate', 'test'}, "Parameter `mode` should only be 'validate' or 'test'."

         self.input_mapping = input_mapping
         self.output_mapping = output_mapping

         if not isinstance(dataloaders, dict):
             dataloaders = {None: dataloaders}
-        if mode == "validate":
-            self._evaluate_step = self.driver.validate_step
-            self.driver.set_dataloader(validate_dataloaders=dataloaders)
-        else:
-            self._evaluate_step = self.driver.test_step
-            self.driver.set_dataloader(test_dataloaders=dataloaders)
-        self.mode = mode

         self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn)
         self.separator = kwargs.get('separator', '#')
         self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
@@ -123,10 +115,14 @@ class Evaluator:
         self._metric_wrapper = None
         _ = self.metrics_wrapper  # trigger the check

-        assert self.driver.has_validate_dataloaders() or self.driver.has_test_dataloaders()
         self.driver.setup()
         self.driver.barrier()

+        if evaluate_fn is not None and not isinstance(evaluate_fn, str):
+            raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
+        self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
+        self.evaluate_fn = evaluate_fn

         self.dataloaders = {}
         for name, dl in dataloaders.items():  # replace with the correct sampler
             dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
@@ -136,9 +132,10 @@ class Evaluator:
         if self.progress_bar == 'auto':
             self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'

-        self.driver.check_evaluator_mode(self.mode)
         self.driver.barrier()

+        self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False)

     def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
         """
         Returns a dict whose keys are the metric names and whose values are the corresponding metric results.
@@ -156,11 +153,6 @@ class Evaluator:
         assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
         assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."

-        if self.mode == 'validate':
-            assert self.driver.has_validate_dataloaders()
-        else:
-            assert self.driver.has_test_dataloaders()

         metric_results = {}
         self.reset()
         evaluate_context = self.driver.get_evaluate_context()
@@ -235,13 +227,6 @@ class Evaluator:
             f_rich_progress.destroy_task(self._rich_task_id)
             delattr(self, '_rich_task_id')

-    @property
-    def eval_dataloaders(self):
-        if self.mode == "validate":
-            return self.driver.validate_dataloaders
-        else:
-            return self.driver.test_dataloaders

     @property
     def evaluate_batch_loop(self):
         return self._evaluate_batch_loop
@@ -296,13 +281,13 @@ class Evaluator:

     def evaluate_step(self, batch):
         """
-        Passes the batch to the model for processing; whether evaluate or test is run is decided by the current mode. The returned result is
+        Passes the batch to the model for processing; whether evaluate or test is run is decided by the current evaluate_fn. The returned result is
         processed by output_mapping before being returned.

         :param batch:
         :return:
         """
-        outputs = self._evaluate_step(batch)
+        outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn)
         outputs = match_and_substitute_params(self.output_mapping, outputs)
         return outputs
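
In short, the Evaluator's old two-valued mode parameter becomes a method name. A toy sketch of the resolution order this implies (resolve_evaluate_fn and ModelWithEvaluateStep are hypothetical names; in fastNLP the resolution happens inside driver.get_model_call_fn):

from typing import Dict

class ModelWithEvaluateStep:
    def evaluate_step(self, batch) -> Dict:
        return {"pred": batch["x"] * 2}

    def forward(self, batch) -> Dict:
        return {"pred": batch["x"]}

def resolve_evaluate_fn(model, evaluate_fn=None):
    # Mirrors the constructor logic above: validate the type, default the name
    # to "evaluate_step", and fall back to forward when the method is missing.
    if evaluate_fn is not None and not isinstance(evaluate_fn, str):
        raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.")
    name = "evaluate_step" if evaluate_fn is None else evaluate_fn
    return getattr(model, name, model.forward)

fn = resolve_evaluate_fn(ModelWithEvaluateStep())
print(fn({"x": 3}))  # {'pred': 6}: evaluate_step was found and used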




+60 -37 fastNLP/core/controllers/trainer.py

@@ -41,19 +41,20 @@ class Trainer(TrainerEventTrigger):
optimizers, optimizers,
device: Optional[Union[int, List[int], str]] = "cpu", device: Optional[Union[int, List[int], str]] = "cpu",
n_epochs: int = 20, n_epochs: int = 20,
validate_dataloaders=None,
evaluate_dataloaders=None,
batch_step_fn: Optional[Callable] = None, batch_step_fn: Optional[Callable] = None,
validate_batch_step_fn: Optional[Callable] = None,
validate_mode: Union[str, callable] = 'validate',
evaluate_batch_step_fn: Optional[Callable] = None,
train_fn: Optional[str] = None,
evaluate_fn: Optional[str] = None,
callbacks: Union[List[Callback], Callback, None] = None, callbacks: Union[List[Callback], Callback, None] = None,
metrics: Optional[dict] = None, metrics: Optional[dict] = None,
validate_every: Optional[Union[int, callable]] = -1,
evaluate_every: Optional[Union[int, Callable]] = -1,
input_mapping: Optional[Union[Callable, Dict]] = None, input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False, model_wo_auto_param_call: bool = False,
accumulation_steps: int = 1, accumulation_steps: int = 1,
fp16: bool = False, fp16: bool = False,
monitor: Union[str, callable] = None,
monitor: Union[str, Callable] = None,
larger_better: bool = True, larger_better: bool = True,
marker: Optional[str] = None, marker: Optional[str] = None,
**kwargs **kwargs
@@ -79,19 +80,19 @@ class Trainer(TrainerEventTrigger):
         4. list(int): if there is more than one device, it should be set this way; when `device` is a list we default to `TorchDDPDriver`;
         5. None: if None, do nothing to the model;
        :param n_epochs: the total number of training epochs, 20 by default;
-       :param validate_dataloaders: the validation datasets; may be a single dataset or several; when several, note that it must be a Dict; defaults
+       :param evaluate_dataloaders: the validation datasets; may be a single dataset or several; when several, note that it must be a Dict; defaults
         to None;
        :param batch_step_fn: replaces the `batch_step_fn` function in the `TrainBatchLoop`; note that its two parameters must be `trainer` and
         `batch`; defaults to None;
-       :param validate_batch_step_fn: replaces the `batch_step_fn` function of the `EvaluateBatchLoop` inside the 'Evaluator'; note that its
+       :param evaluate_batch_step_fn: replaces the `batch_step_fn` function of the `EvaluateBatchLoop` inside the 'Evaluator'; note that its
         two parameters must be `evaluator` and `batch`; defaults to None;
-       :param validate_mode: controls the mode of the `Evaluator` built into the `Trainer`; its value should be one of ["validate", "test"],
-        "validate" by default; when "validate", first look for a validate_step method on the model, then fall back to
-        a test_step method, and if neither is found use the model's forward computation; when "test", first look for a test_step method,
-        then "validate_step", and if neither is found use the model's forward computation.
+       :param train_fn: controls which function the `Trainer` calls for the forward pass during training, e.g. `model.train_step` versus `model.forward`;
+        defaults to None, in which case we use `train_step` as the forward function, falling back to `model.forward` if the model does not define that method;
+       :param evaluate_fn: controls the mode of the `Evaluator` built into the `Trainer`; should be None or a string, used the same way as train_fn;
+        note that we pass this parameter straight to the built-in Evaluator (if it is not None);
        :param callbacks: the callback classes triggered during training; this parameter should be a list whose every element inherits from the `Callback` class;
        :param metrics: should be a dict whose keys are monitor names, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
-       :param validate_every: may be a negative number, a positive number or a function; negative means validate every that many epochs; positive means validate every that many batches;
+       :param evaluate_every: may be a negative number, a positive number or a function; negative means validate every that many epochs; positive means validate every that many batches;
         a function means a user-supplied function controlling how often the Trainer validates; it should accept the current trainer object as its parameter and
         return a bool, True meaning a validate is needed; it is called after every batch to decide whether to validate.
        :param input_mapping: should be a dict or a function describing how a batch of training data is mapped once fetched at the current step; if it is
@@ -105,10 +106,10 @@ class Trainer(TrainerEventTrigger):
         if the batch is a `dataclass`, we first convert it to a Dict and then apply the mapping above;
        :param model_wo_auto_param_call: whether to disable calling our auto_param_call during training to automatically match the batch to the forward function's parameters;
         if this value is False, and the batch is a dict, we extract the objects the forward function needs from the batch and pass them to forward; if this value
-        is True, we pass the batch straight through to the model. Note this parameter applies to `train_step`, `validate_step` and `test_step`;
+        is True, we pass the batch straight through to the model. Note this parameter applies to `train_step`, `evaluate_step` and `test_step`;
        :param accumulation_steps: the number of gradient-accumulation steps, i.e. the optimizer steps once every that many batches; defaults to 1;
        :param fp16: whether to enable mixed-precision training; defaults to False;
-       :param monitor: the default monitor metric name when validate_dataloaders is present. A callback that has a monitor parameter but did not
+       :param monitor: the default monitor metric name when evaluate_dataloaders is present. A callback that has a monitor parameter but did not
        set it at callback initialization will use this value. If no exactly matching name is found in the evaluation results, the shortest-common-string algorithm picks the best
        match as the monitor. A function may also be passed in; it accepts the evaluation results (a dict) and returns a float used as the monitor value.
        :param larger_better: whether larger monitor values are better.
@@ -136,10 +137,15 @@ class Trainer(TrainerEventTrigger):
         else:
             self.driver_name = driver.__class__.__name__
         self.device = device
+        if train_dataloader is None:
+            raise ValueError("Parameter `train_dataloader` can not be None.")
+        self.train_dataloader = train_dataloader
+        self.evaluate_dataloaders = evaluate_dataloaders
         self.optimizers = optimizers
         self.fp16 = fp16
         self.input_mapping = input_mapping
         self.output_mapping = output_mapping
+        self.evaluate_fn = evaluate_fn

         self.batch_step_fn = batch_step_fn
         if batch_step_fn is not None:
@@ -168,13 +174,13 @@ class Trainer(TrainerEventTrigger):
             optimizers=optimizers,
             device=device,
             n_epochs=n_epochs,
-            validate_dataloaders=validate_dataloaders,
+            validate_dataloaders=evaluate_dataloaders,
             batch_step_fn=batch_step_fn,
-            validate_batch_step_fn=validate_batch_step_fn,
-            validate_mode=validate_mode,
+            validate_batch_step_fn=evaluate_batch_step_fn,
+            evaluate_fn=evaluate_fn,
             callbacks=callbacks,
             metrics=metrics,
-            validate_every=validate_every,
+            validate_every=evaluate_every,
             input_mapping=input_mapping,
             output_mapping=output_mapping,
             model_wo_auto_param_call=model_wo_auto_param_call,
@@ -185,9 +191,6 @@ class Trainer(TrainerEventTrigger):
         )
         self.driver.set_optimizers(optimizers=optimizers)

-        if train_dataloader is not None:
-            self.driver.set_dataloader(train_dataloader=train_dataloader)

         # initialize the callback manager;
         self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto'))
         # register all of the functional callbacks;
@@ -213,25 +216,25 @@ class Trainer(TrainerEventTrigger):
_dist_sampler = None _dist_sampler = None


""" 设置内部的 Evaluator """ """ 设置内部的 Evaluator """
if metrics is None and validate_dataloaders is not None:
if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.") raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.")


if metrics is not None and validate_dataloaders is None:
if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")


self.evaluator = None self.evaluator = None
self.monitor = monitor self.monitor = monitor
self.larger_better = larger_better self.larger_better = larger_better
if metrics is not None and validate_dataloaders is not None:
check_validate_every(validate_every)
if metrics is not None and evaluate_dataloaders is not None:
check_validate_every(evaluate_every)
self.evaluator = Evaluator( self.evaluator = Evaluator(
model=model, model=model,
dataloaders=validate_dataloaders,
dataloaders=evaluate_dataloaders,
metrics=metrics, metrics=metrics,
driver=self.driver, driver=self.driver,
device=device, device=device,
batch_step_fn=validate_batch_step_fn,
mode=validate_mode,
batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn,
input_mapping=input_mapping, input_mapping=input_mapping,
output_mapping=output_mapping, output_mapping=output_mapping,
fp16=fp16, fp16=fp16,
@@ -241,12 +244,16 @@ class Trainer(TrainerEventTrigger):
             )

         self.metrics = metrics
-        self.validate_every = validate_every
+        self.validate_every = evaluate_every

-        assert self.driver.has_train_dataloader()
         self.driver.setup()
         self.driver.barrier()

+        if train_fn is not None and not isinstance(train_fn, str):
+            raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
+        self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
+        self.train_fn = train_fn

         self.dataloader = self.train_dataloader
         self.driver.set_deterministic_dataloader(self.dataloader)

@@ -257,6 +264,7 @@ class Trainer(TrainerEventTrigger):
         self.on_after_trainer_initialized(self.driver)

         self.driver.barrier()
+        self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True)


     def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
             num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
@@ -273,6 +281,7 @@ class Trainer(TrainerEventTrigger):
         by default non-distributed drivers catch it and distributed drivers do not (they cannot catch it)
        :return:
        """
+
        if catch_KeyboardInterrupt is None:
            catch_KeyboardInterrupt = not self.driver.is_distributed()
        else:
@@ -343,7 +352,8 @@ class Trainer(TrainerEventTrigger):
             _validate_res: dict = validate_fn()
             trainer.on_validate_end(_validate_res)

-        self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
+        if self.evaluator is not None:
+            self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))

     def step_validate(self):
         """
@@ -489,11 +499,6 @@ class Trainer(TrainerEventTrigger):
         self.has_checked_train_batch_loop = True

     """ some properties the Trainer needs """
-
-    @property
-    def train_dataloader(self):
-        return self.driver.train_dataloader

     @property
     def driver(self):
         return self._driver
@@ -684,7 +689,7 @@ class Trainer(TrainerEventTrigger):

     def train_step(self, batch):
         with self.driver.auto_cast():
-            outputs = self.driver.train_step(batch)
+            outputs = self.driver.model_call(batch, self._train_step, self._train_step_signature_fn)
             outputs = match_and_substitute_params(self.output_mapping, outputs)
             return outputs


@@ -814,6 +819,24 @@ class Trainer(TrainerEventTrigger):
     def data_device(self):
         return self.driver.data_device

+    """ dataloader property """
+
+    @property
+    def train_dataloader(self):
+        return self._train_dataloader
+
+    @train_dataloader.setter
+    def train_dataloader(self, train_dataloader):
+        self._train_dataloader = train_dataloader
+
+    @property
+    def evaluate_dataloaders(self):
+        return self._evaluate_dataloaders
+
+    @evaluate_dataloaders.setter
+    def evaluate_dataloaders(self, evaluate_dataloaders):
+        self._evaluate_dataloaders = evaluate_dataloaders
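
The evaluate_every parameter documented above accepts three shapes. A toy restatement of those documented semantics (should_validate is a hypothetical helper, not the Trainer's actual implementation):

from typing import Callable, Union

def should_validate(evaluate_every: Union[int, Callable], batch_idx: int,
                    epoch_idx: int, epoch_ended: bool) -> bool:
    # negative -> validate every |n| epochs; positive -> validate every n batches;
    # callable -> the user decides after each batch.
    if callable(evaluate_every):
        return bool(evaluate_every(None))  # the real Trainer passes itself here
    if evaluate_every < 0:
        return epoch_ended and (epoch_idx + 1) % (-evaluate_every) == 0
    return (batch_idx + 1) % evaluate_every == 0

print(should_validate(-2, batch_idx=0, epoch_idx=1, epoch_ended=True))     # True: every 2nd epoch
print(should_validate(100, batch_idx=99, epoch_idx=0, epoch_ended=False))  # True: every 100 batches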









+1 -1 fastNLP/core/controllers/utils/utils.py

@@ -128,6 +128,6 @@ class _TruncatedDataLoader:

 def check_validate_every(validate_every):
     if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
-        raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
+        raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.")
     if callable(validate_every):
         _check_valid_parameters_number(validate_every, expected_params=['trainer'])

+31 -91 fastNLP/core/drivers/driver.py

@@ -1,7 +1,7 @@
 import os
 import signal
 import sys
-from typing import Any, Sequence, List, Optional, Callable, Dict, Union
+from typing import Any, Sequence, List, Optional, Callable, Dict, Union, Tuple
 from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
@@ -79,41 +79,44 @@ class Driver(ABC):
     """

     @abstractmethod
-    def train_step(self, batch):
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
         """
-        Implements the training forward pass by calling the model's own `train_step` or `forward` method;
-        if we detect that the user's model implements train_step
+        Implements the training forward pass by calling `fn`;
+        note that Trainer and Evaluator call this function to run the network's forward pass, and the parameter `fn` passed to this function is the
+        function returned by `get_model_call_fn`;

         :param batch: the data of the current batch; may be a dict or some other type;
-        :return: the result returned by the model's `train_step` or `forward` method (should be a dict or dataclass, but we do not need to check it);
+        :param fn: the function passed in by the Trainer for running one forward pass of the network;
+        :param signature_fn: the signature function passed in by the Trainer for one forward pass of the network; when the batch is a Dict we automatically call the auto_param_call
+         function, and some wrapped models need to expose their real function signature, e.g. DistributedDataParallel is invoked through forward but its signature should be model.module.forward;
+        :return: the result returned by `fn` (should be a dict or dataclass, but we do not need to check it);
         """
-        raise NotImplementedError("Each specific driver should implemented its own `train_step` function.")
+        raise NotImplementedError("Each specific driver should implemented its own `model_call` function.")

-    def validate_step(self, batch):
+    @abstractmethod
+    def get_model_call_fn(self, fn: str) -> Tuple:
         """
-        Implements the evaluation forward pass by calling the model's own `validate_step` or `forward` method;
+        This function receives the Trainer's train_fn or the Evaluator's evaluate_fn and returns the function argument actually passed in when calling driver.model_call;
+        it is called by the Trainer and the Evaluator after the driver.setup function;

-        :param batch: the data of the current batch; may be a dict or some other type;
-        :return: the result returned by the model's `validate_step` or `forward` method (should be a dict or dataclass, but we do not need to check it);
-        """
-        raise NotImplementedError("Each specific driver should implemented its own `validate_step` function.")
+        The purpose of this function is to pull the concrete model_call function out of the driver and attach it to the Trainer or the Evaluator;
+        this is because in the new design, which model method is used for the `train step` or the `evaluate step` is determined by the extra parameters `train_fn` and
+        `evaluate_fn`, which are in turn controlled by the Trainer and the Evaluator respectively; the logic that picks the concrete `train step fn` and
+        `evaluate step fn` therefore cannot live in each driver's initialization (when the Trainer initializes the first driver the Evaluator has not been initialized yet, but
+        picking the `evaluate step fn` requires the Evaluator's initialization), so we abstract this logic into this function;

-    def test_step(self, batch):
-        """
-        Implements the evaluation forward pass by calling the model's own `test_step` or `forward` method;
+        This function should use the parameter `fn` to decide the actual function to return, with the following logic:
+        1. if fn == "train_step" or "evaluate_step", check the given model; if the model does not define a method `fn`, default to calling the model's `forward`
+         function, and emit a warning;
+        2. if fn is some other string, raise an error directly if the model does not define a method `fn`;
+        Note that different drivers need extra checks of their own; e.g. in DDPDriver, when the passed-in model is itself already a DistributedDataParallel, we can only call the model's
+        forward function, so an extra warning is needed; the especially tricky point here is that the driver also changes the model itself during setup (DDPDriver), so
+        it may need to additionally record what form the model originally passed to the driver had;

-        :param batch: the data of the current batch; may be a dict or some other type;
-        :return: the result returned by the model's `test_step` or `forward` method (should be a dict or dataclass, but we do not need to check it);
+        :param fn: should be a string; this function uses it to decide which method of the model to return
+        :return: a tuple containing two functions, to be passed in when calling driver.model_call
         """
-        raise NotImplementedError("Each specific driver should implemented its own `test_step` function.")
-
-    def check_evaluator_mode(self, mode: str):
-        r"""
-        The logic of validate_step and test_step in the concrete drivers is that if the model does not implement the method itself, we check whether it
-        implements the other one; so if the user's evaluator mode is validate but the passed-in model implements test_step rather than validate_step,
-        we should warn the user about this behaviour;
-        """
-        raise NotImplementedError("Each specific driver should implemented its own `check_evaluator_mode` function.")
+        raise NotImplementedError("Each specific driver should implemented its own `get_model_call_fn` function.")

     @property
     def model(self):
@@ -123,59 +126,8 @@ class Driver(ABC):
     def model(self, model):
         self._model = model

-    @property
-    def train_dataloader(self):
-        return self._train_dataloader
-
-    @train_dataloader.setter
-    def train_dataloader(self, train_dataloader: Any):
-        self._train_dataloader = train_dataloader
-
-    @property
-    def validate_dataloaders(self):
-        return self._validate_dataloaders
-
-    @validate_dataloaders.setter
-    def validate_dataloaders(self, validate_dataloaders: Any):
-        self._validate_dataloaders = validate_dataloaders
-
-    @property
-    def test_dataloaders(self):
-        return self._test_dataloaders
-
-    @test_dataloaders.setter
-    def test_dataloaders(self, test_dataloaders: Any):
-        self._test_dataloaders = test_dataloaders
-
-    @property
-    def predict_dataloaders(self):
-        return self._predict_dataloaders
-
-    @predict_dataloaders.setter
-    def predict_dataloaders(self, predict_dataloaders: Any):
-        self._predict_dataloaders = predict_dataloaders
-
-    def set_dataloader(self, **kwargs):
-        r"""
-        Sets the data used during training or evaluation; used by the trainer and the evaluator to attach the dataloaders to each concrete driver;
-
-        :param kwargs: the input data; should be set with 'keyword-only' arguments;
-        """
-        if "train_dataloader" in kwargs:
-            self.train_dataloader = kwargs["train_dataloader"]
-            self._check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True)
-        if "validate_dataloaders" in kwargs:
-            self.validate_dataloaders = kwargs["validate_dataloaders"]
-            self._check_dataloader_legality(self.validate_dataloaders, "validate_dataloaders", is_train=False)
-        if "test_dataloaders" in kwargs:
-            self.test_dataloaders = kwargs["test_dataloaders"]
-            self._check_dataloader_legality(self.test_dataloaders, "test_dataloaders", is_train=False)
-        if "predict_dataloaders" in kwargs:
-            self.predict_dataloaders = kwargs["predict_dataloaders"]
-            self._check_dataloader_legality(self.predict_dataloaders, "predict_dataloaders", is_train=False)
-
     @staticmethod
-    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
+    def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         r"""
         This function checks the legality of the dataloader after the trainer or the evaluator sets it, since different deep learning frameworks require
         different dataloader behaviours;
@@ -183,19 +135,7 @@ class Driver(ABC):
         :param dataloader: the input `dataloader` to check;
         :param dataloader_name:
         """
-        raise NotImplementedError("Each specific driver should implemented its own `_check_dataloader_legality` function.")
-
-    def has_train_dataloader(self):
-        return "_train_dataloader" in self.__dict__
-
-    def has_validate_dataloaders(self):
-        return "_validate_dataloaders" in self.__dict__
-
-    def has_test_dataloaders(self):
-        return "_test_dataloaders" in self.__dict__
-
-    def has_predict_dataloaders(self):
-        return "_predict_dataloaders" in self.__dict__
+        raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.")

     @property
     def optimizers(self) -> List:
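
The fallback rules in the get_model_call_fn docstring above can be condensed into a few lines. A sketch under the stated contract (get_model_call_fn here is a free function and ForwardOnlyModel a toy model, unlike the driver method in fastNLP):

from typing import Callable, Dict, Optional, Tuple
import warnings

class ForwardOnlyModel:
    # A model that defines no step methods at all.
    def forward(self, batch) -> Dict:
        return {"out": batch}

def get_model_call_fn(model, fn: str) -> Tuple[Callable, Optional[Callable]]:
    # Toy restatement of rules 1 and 2 from the docstring above.
    if hasattr(model, fn) and callable(getattr(model, fn)):
        return getattr(model, fn), None                                  # use model.<fn> directly
    if fn in {"train_step", "evaluate_step"}:
        warnings.warn(f"`{fn}` not found; falling back to `forward`.")   # rule 1
        return model.forward, model.forward
    raise RuntimeError(f"There is no `{fn}` method in your model.")      # rule 2

step_fn, sig_fn = get_model_call_fn(ForwardOnlyModel(), "train_step")
print(step_fn({"x": 1}))  # {'out': {'x': 1}}: forward was used as the fallback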


+7 -7 fastNLP/core/drivers/jittor_driver/jittor_driver.py

@@ -39,7 +39,7 @@ class JittorDriver(Driver):
         self.grad_scaler = _grad_scaler()

     @staticmethod
-    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
+    def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         # JittorDataLoader is implemented in fastnlp
         # TODO: should passing in a Dataset be allowed?
         if is_train:
@@ -64,18 +64,18 @@ class JittorDriver(Driver):
     def check_evaluator_mode(self, mode: str):
         model = self.unwrap_model()
         if mode == "validate":
-            if not hasattr(model, "validate_step"):
+            if not hasattr(model, "evaluate_step"):
                 if hasattr(model, "test_step"):
                     logger.warning_once(
-                        "Your model does not have 'validate_step' method but has 'test_step' method, but you"
-                        "are using 'mode=validate', we are going to use 'test_step' to substitute for"
-                        "'validate_step'.")
+                        "Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
+                        "are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for"
+                        "'evaluate_step'.")

         else:
             if not hasattr(model, "test_step"):
-                if hasattr(model, "validate_step"):
+                if hasattr(model, "evaluate_step"):
                     logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
-                                        "are using 'mode=test', we are going to use 'validate_step' to substitute for"
+                                        "are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for"
                                         "'test_step'.")

     def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None):


+5 -5 fastNLP/core/drivers/jittor_driver/single_device.py

@@ -35,8 +35,8 @@ class JittorSingleDriver(JittorDriver):
             model = self.unwrap_model()
             self._train_signature_fn = model.execute

-        if hasattr(self.model, "validate_step"):
-            self._validate_step = self.model.validate_step
+        if hasattr(self.model, "evaluate_step"):
+            self._validate_step = self.model.evaluate_step
             self._validate_signature_fn = None
         elif hasattr(self.model, "test_step"):
             self._validate_step = self.model.test_step
@@ -49,9 +49,9 @@ class JittorSingleDriver(JittorDriver):
         if hasattr(self.model, "test_step"):
             self._test_step = self.model.test_step
             self._test_signature_fn = None
-        elif hasattr(self.model, "validate_step"):
-            self._test_step = self.model.validate_step
-            self._test_signature_fn = self.model.validate_step
+        elif hasattr(self.model, "evaluate_step"):
+            self._test_step = self.model.evaluate_step
+            self._test_signature_fn = self.model.evaluate_step
         else:
             self._test_step = self.model
             model = self.unwrap_model()


+3 -3 fastNLP/core/drivers/paddle_driver/fleet.py

@@ -118,11 +118,11 @@ class PaddleFleetDriver(PaddleDriver):
                 " call `forward` function instead of `train_step` and you should note that.")
             self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

-        if hasattr(model, "validate_step"):
+        if hasattr(model, "evaluate_step"):
             logger.warning(
                 "Notice your model is a `paddle.DataParallel` model. And your "
-                "model also implements the `validate_step` method, which we can not call actually, "
-                "we will call `forward` function instead of `validate_step` and you should note that.")
+                "model also implements the `evaluate_step` method, which we can not call actually, "
+                "we will call `forward` function instead of `evaluate_step` and you should note that.")
             self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

         if hasattr(model, "test_step"):


+8 -8 fastNLP/core/drivers/paddle_driver/paddle_driver.py

@@ -72,7 +72,7 @@ class PaddleDriver(Driver):
             optimizer.clear_grad()

     @staticmethod
-    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
+    def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         r"""
         This function checks the legality of the dataloader after the trainer or the evaluator sets it.
         The given dataloader must be a `paddle.io.DataLoader` or a dict containing that type.
@@ -117,24 +117,24 @@ class PaddleDriver(Driver):

     def check_evaluator_mode(self, mode: str):
         r"""
-        The logic of validate_step and test_step in the concrete drivers is that if the model does not implement the method itself, we check whether it implements the other one;
-        so if the user's evaluator mode is validate but the passed-in model implements test_step rather than validate_step,
+        The logic of evaluate_step and test_step in the concrete drivers is that if the model does not implement the method itself, we check whether it implements the other one;
+        so if the user's evaluator evaluate_fn is validate but the passed-in model implements test_step rather than evaluate_step,
         we should warn the user about this behaviour;
         """
         model = self.unwrap_model()
         if mode == "validate":
-            if not hasattr(model, "validate_step"):
+            if not hasattr(model, "evaluate_step"):
                 if hasattr(model, "test_step"):
                     logger.warning(
-                        "Your model does not have 'validate_step' method but has 'test_step' method, but you"
+                        "Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
                         "are using 'Evaluator.validate', we are going to use 'test_step' to substitute for"
-                        "'validate_step'.")
+                        "'evaluate_step'.")

         else:
             if not hasattr(model, "test_step"):
-                if hasattr(model, "validate_step"):
+                if hasattr(model, "evaluate_step"):
                     logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
-                                        "are using 'Evaluator.test', we are going to use 'validate_step' to substitute for"
+                                        "are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for"
                                         "'test_step'.")

     @staticmethod


+8 -8 fastNLP/core/drivers/paddle_driver/single_device.py

@@ -50,10 +50,10 @@ class PaddleSingleDriver(PaddleDriver):
             self._train_step = self.model
             self._train_signature_fn = model.forward

-        if hasattr(model, "validate_step"):
+        if hasattr(model, "evaluate_step"):
             logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
-                           "implements the `validate_step` method, which we can not call actually, we "
-                           "will call `forward` function instead of `validate_step` and you should note that.")
+                           "implements the `evaluate_step` method, which we can not call actually, we "
+                           "will call `forward` function instead of `evaluate_step` and you should note that.")
             self._validate_step = self.model
             self._validate_signature_fn = model.forward

@@ -73,8 +73,8 @@ class PaddleSingleDriver(PaddleDriver):
             model = self.unwrap_model()
             self._train_signature_fn = model.forward

-        if hasattr(self.model, "validate_step"):
-            self._validate_step = self.model.validate_step
+        if hasattr(self.model, "evaluate_step"):
+            self._validate_step = self.model.evaluate_step
             self._validate_signature_fn = None
         elif hasattr(self.model, "test_step"):
             self._validate_step = self.model.test_step
@@ -87,9 +87,9 @@ class PaddleSingleDriver(PaddleDriver):
         if hasattr(self.model, "test_step"):
             self._test_step = self.model.test_step
             self._test_signature_fn = None
-        elif hasattr(self.model, "validate_step"):
-            self._test_step = self.model.validate_step
-            self._test_signature_fn = self.model.validate_step
+        elif hasattr(self.model, "evaluate_step"):
+            self._test_step = self.model.evaluate_step
+            self._test_signature_fn = self.model.evaluate_step
         else:
             self._test_step = self.model
             model = self.unwrap_model()


+7 -7 fastNLP/core/drivers/paddle_driver/utils.py

@@ -108,11 +108,11 @@ class _FleetWrappingModel(Layer):
             self._train_step = self.model
             self._train_signature_fn = model.forward

-        if hasattr(model, "validate_step"):
+        if hasattr(model, "evaluate_step"):
             logger.warning(
                 "Notice your model is a `paddle.DataParallel` model. And your "
-                "model also implements the `validate_step` method, which we can not call actually, "
-                "we will call `forward` function instead of `validate_step` and you should note that.")
+                "model also implements the `evaluate_step` method, which we can not call actually, "
+                "we will call `forward` function instead of `evaluate_step` and you should note that.")
             self._validate_step = self.model
             self._validate_signature_fn = model.forward

@@ -131,7 +131,7 @@ class _FleetWrappingModel(Layer):
             self._train_step = model
             self._train_signature_fn = model.forward

-        if hasattr(model, "validate_step"):
+        if hasattr(model, "evaluate_step"):
             self._validate_step = model.validate_step
             self._validate_signature_fn = None
         elif hasattr(model, "test_step"):
@@ -144,7 +144,7 @@ class _FleetWrappingModel(Layer):
         if hasattr(model, "test_step"):
             self._test_step = model.test_step
             self._test_signature_fn = None
-        elif hasattr(model, "validate_step"):
+        elif hasattr(model, "evaluate_step"):
             self._test_step = model.validate_step
             self._test_signature_fn = None
         else:
@@ -172,9 +172,9 @@ class _FleetWrappingModel(Layer):
             else:
                 return self._test_step(batch)
         elif forward_state == ForwardState.PREDICT:
-            raise NotImplementedError("'PREDICT' mode has not been implemented.")
+            raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.")
         else:
-            raise NotImplementedError("You should direct a concrete mode.")
+            raise NotImplementedError("You should direct a concrete evaluate_fn.")

 class DummyGradScaler:
     """


+40 -61 fastNLP/core/drivers/torch_driver/ddp.py

@@ -4,7 +4,7 @@ import __main__
 import socket
 import numpy as np
 from time import sleep
-from typing import List, Optional, Union, Dict
+from typing import List, Optional, Union, Dict, Tuple, Callable
 from functools import partial

 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
@@ -21,8 +21,6 @@ __all__ = [
 from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import (
     _DDPWrappingModel,
-    ForwardState,
-    _MODE_PARAMETER,
     reset_seed,
     replace_sampler,
     replace_batch_sampler
@@ -158,10 +156,10 @@ class TorchDDPDriver(TorchDriver):
     ————————————————————————————————————————————————————————————————————————————————————————————————————————

     3. what _DDPWrappingModel is for;
-    since we both need to call the model's `train_step`, `validate_step` and `test_step` methods and need `DistributedDataParallel`'s
+    since we both need to call the model's `train_step`, `evaluate_step` and `test_step` methods and need `DistributedDataParallel`'s
     forward function to synchronize gradients across devices, we first wrap the model in an extra layer; during forward the call first goes through `DistributedDataParallel`'s
     forward method and then through `_DDPWrappingModel`'s forward method, where we decide whether to call the model's own
-    forward function or its `train_step`, `validate_step` or `test_step` method.
+    forward function or its `train_step`, `evaluate_step` or `test_step` method.

     4. how `TorchDDPDriver` handles an exception in one of its processes;

@@ -204,37 +202,6 @@ class TorchDDPDriver(TorchDriver):
             # we simply set model_device to None;
             self.model_device = None

-            def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
-                if isinstance(batch, Dict) and not wo_auto_param_call:
-                    return auto_param_call(step_fn, batch, signature_fn=signature_fn)
-                else:
-                    return step_fn(batch)
-
-            model = model.module
-            if hasattr(model, "train_step"):
-                logger.warning(
-                    "Notice your model is a `DistributedDataParallel` model. And your "
-                    "model also implements the `train_step` method, which we can not call actually, we will"
-                    " call `forward` function instead of `train_step` and you should note that.")
-            self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-            # self._train_signature_fn = model.forward
-
-            if hasattr(model, "validate_step"):
-                logger.warning(
-                    "Notice your model is a `DistributedDataParallel` model. And your "
-                    "model also implements the `validate_step` method, which we can not call actually, "
-                    "we will call `forward` function instead of `validate_step` and you should note that.")
-            self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-            # self._validate_signature_fn = model.forward
-
-            if hasattr(model, "test_step"):
-                logger.warning(
-                    "Notice your model is a `DistributedDataParallel` model. And your "
-                    "model also implements the `test_step` method, which we can not call actually, we will"
-                    " call `forward` function instead of `test_step` and you should note that.")
-            self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-            # self._test_signature_fn = model.forward
-
         # when the user initializes DDP outside by themselves, we set model_device to None, and the user can move the data to the right machine through `data_device`;
         self._data_device = kwargs.get("data_device", None)
         if isinstance(self._data_device, int):
@@ -253,7 +220,6 @@ class TorchDDPDriver(TorchDriver):
         # world_size is the global number of GPUs;
         self.world_size = None  # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
         self.global_rank = 0
-        self._configured = False  # guards against calling the configure_ddp() function repeatedly

         self._ddp_kwargs = kwargs.get("torch_ddp_kwargs", {})
         check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__)
@@ -268,8 +234,8 @@ class TorchDDPDriver(TorchDriver):
             os.makedirs(name=self.output_from_new_proc, exist_ok=True)
             self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)

-        # this flag exists because the evaluator also performs setup, which is clearly unnecessary and should not happen;
-        self._has_setup = False
+        self._has_setup = False  # this flag exists because the evaluator also performs setup, which is clearly unnecessary and should not happen;
+        self._has_ddpwrapped = False  # whether the passed-in model has already been wrapped by _DDPWrappingModel;

     def setup(self):
         if self._has_setup:
@@ -341,24 +307,16 @@ class TorchDDPDriver(TorchDriver):
         self._pids = self.tensor_to_numeric(self._pids)

     def configure_ddp(self):
-        if not self._configured and not isinstance(self.model, DistributedDataParallel):
+        if not isinstance(self.model, DistributedDataParallel):
             self.model = DistributedDataParallel(
                 # note that self.model_device here is of `torch.device` type, hence self.model_device.index;
                 _DDPWrappingModel(self.model), device_ids=[self.model_device.index],
                 **self._ddp_kwargs
             )
-
-            self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
-            self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
-            self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)
-
-            self._configured = True
+            self._has_ddpwrapped = True

     def open_subprocess(self):
         if self.local_rank == 0:
-            # self._consensus_file = Path(tempfile.mkstemp()[1])
-            # self._consensus_file.unlink()

             # Script called as `python a/b/c.py`
             if __main__.__spec__ is None:  # pragma: no-cover
                 # pull out the commands used to run the script and resolve the abs file path
@@ -432,18 +390,39 @@ class TorchDDPDriver(TorchDriver):
             return self._data_device
         return self.model_device

-    def train_step(self, batch):
-        # note that self.model here is already a 'fastNLP.drivers.utils._DDPWrappingModel';
-        # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TRAIN})
-        return self._train_step(batch)
-
-    def validate_step(self, batch):
-        # return self.model(batch, **{_MODE_PARAMETER: ForwardState.VALIDATE})
-        return self._validate_step(batch)
-
-    def test_step(self, batch):
-        # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
-        return self._test_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if self._has_ddpwrapped:
+            return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
+                              wo_auto_param_call=self.wo_auto_param_call)
+        else:
+            if isinstance(batch, Dict) and not self.wo_auto_param_call:
+                return auto_param_call(fn, batch, signature_fn=signature_fn)
+            else:
+                return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        model = self.unwrap_model()
+        if self._has_ddpwrapped:
+            if hasattr(model, fn):
+                fn = getattr(model, fn)
+                if not callable(fn):
+                    raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+                return fn, None
+            elif fn in {"train_step", "evaluate_step"}:
+                return model, model.forward
+            else:
+                raise RuntimeError(f"There is no `{fn}` method in your model.")
+        else:
+            if hasattr(model, fn):
+                logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
+                               f"the `{fn}` method, which we can not call actually, we will"
+                               " call `forward` function instead of `train_step` and you should note that.")
+            elif fn not in {"train_step", "evaluate_step"}:
+                raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
+                                   "`DistributedDataParallel` model, which means that we will only call model.forward "
+                                   "function when we are in forward propagation.")
+
+            return self.model, model.forward

     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None,
                                   reproducible: bool = False):
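
The _has_ddpwrapped branch above relies on _DDPWrappingModel forwarding the resolved step function through DDP's forward pass. A stripped-down sketch of that idea (WrappingModel and TinyModel are illustrative; the kwarg names fastnlp_fn/fastnlp_signature_fn mirror the diff, and the real wrapper also handles auto_param_call):

import torch
from torch import nn

class WrappingModel(nn.Module):
    """Sketch of the _DDPWrappingModel idea: DDP only ever calls forward, so the
    wrapper's forward receives the real step function and calls it inside the
    DDP forward pass, keeping gradient synchronization intact."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, batch, fastnlp_fn, fastnlp_signature_fn=None, wo_auto_param_call=False):
        # A real implementation would route dict batches through auto_param_call
        # using fastnlp_signature_fn; here we just call the resolved step function.
        return fastnlp_fn(batch)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def train_step(self, batch):
        return {"loss": self.linear(batch).sum()}

wrapped = WrappingModel(TinyModel())
out = wrapped(torch.ones(4, 2), fastnlp_fn=wrapped.model.train_step)
print(out["loss"].shape)  # torch.Size([]): a scalar loss produced through the wrapper's forward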


+26 -70 fastNLP/core/drivers/torch_driver/single_device.py

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Union
+from typing import Dict, Union, Callable, Tuple, Optional
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 if _NEED_IMPORT_TORCH:
     import torch
@@ -42,84 +42,40 @@ class TorchSingleDriver(TorchDriver):
self.global_rank = 0 self.global_rank = 0
self.world_size = 1 self.world_size = 1


if isinstance(model, DataParallel):
model = self.unwrap_model()
if hasattr(model, "train_step"):
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = self.model
self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
self._validate_step = self.model
self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = self.model
self._test_signature_fn = model.forward
else:
if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
# 输入的模型是 `DataParallel` 或者 `DistributedDataParallel`,我们需要保证其 signature_fn 是正确的;
model = self.unwrap_model()
self._train_signature_fn = model.forward

if hasattr(self.model, "validate_step"):
self._validate_step = self.model.validate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.test_step
else:
self._validate_step = self.model
model = self.unwrap_model()
self._validate_signature_fn = model.forward

if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "validate_step"):
self._test_step = self.model.validate_step
self._test_signature_fn = self.model.validate_step
else:
self._test_step = self.model
model = self.unwrap_model()
self._test_signature_fn = model.forward

def setup(self): def setup(self):
if self.model_device is not None: if self.model_device is not None:
self.model.to(self.model_device) self.model.to(self.model_device)


def train_step(self, batch) -> Dict:
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理;
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call: if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
return auto_param_call(fn, batch, signature_fn=signature_fn)
else: else:
return self._train_step(batch)
return fn(batch)


def validate_step(self, batch) -> Dict:
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的;
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
def get_model_call_fn(self, fn: str) -> Tuple:
if isinstance(self.model, DataParallel):
model = self.unwrap_model()
if hasattr(model, fn):
logger.warning("Notice your model is a `DataParallel` model. And your model also implements the "
f"`{fn}` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")


def test_step(self, batch) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
elif fn not in {"train_step", "evaluate_step"}:
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
f"`DataParallel` model, which means that we will only call model.forward function "
f"when we are in forward propagation.")

return self.model, model.forward
else: else:
return self._test_step(batch)
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
return fn, None
elif fn in {"train_step", "evaluate_step"}:
return self.model, self.model.forward
else:
raise RuntimeError(f"There is no `{fn}` method in your model.")


def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
reproducible: bool = False): reproducible: bool = False):
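
The net effect of this file's rewrite: the driver no longer binds `_train_step` / `_validate_step` / `_test_step` at construction time. A controller asks once for a callable via `get_model_call_fn(fn)` and then routes every batch through `model_call`. Below is a single-process sketch of that resolution logic, re-implemented outside the driver for illustration; the helper name `resolve_step_fn` and the `**batch` shortcut are assumptions of this sketch, not fastNLP code.

from typing import Callable, Optional, Tuple

import torch
from torch import nn


def resolve_step_fn(model: nn.Module, fn_name: str) -> Tuple[Callable, Optional[Callable]]:
    # Illustrative re-creation of get_model_call_fn's non-DataParallel branch:
    # returns (callable to run, signature_fn for parameter matching or None).
    if hasattr(model, fn_name):
        fn = getattr(model, fn_name)
        if not callable(fn):
            raise RuntimeError(f"The `{fn_name}` attribute is not `Callable`.")
        return fn, None
    elif fn_name in {"train_step", "evaluate_step"}:
        # no custom step: fall back to the model itself, with forward's signature
        return model, model.forward
    else:
        raise RuntimeError(f"There is no `{fn_name}` method in your model.")


def model_call(batch, fn: Callable, signature_fn: Optional[Callable]):
    # The driver uses auto_param_call(fn, batch, signature_fn=...); `**batch` is a
    # simplification that assumes the batch keys match the parameter names.
    if isinstance(batch, dict):
        return fn(**batch)
    return fn(batch)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

    def train_step(self, x, y):
        return {"loss": nn.functional.cross_entropy(self(x), y)}


model = ToyModel()
fn, signature_fn = resolve_step_fn(model, "train_step")     # -> model.train_step, None
batch = {"x": torch.rand(3, 4), "y": torch.randint(0, 2, (3,))}
print(model_call(batch, fn, signature_fn)["loss"].item())

fn, signature_fn = resolve_step_fn(model, "evaluate_step")  # no evaluate_step: forward fallback
print(fn is model, signature_fn is model.forward)           # True True

The returned `signature_fn` matters only in the fallback case: parameter matching is done against `forward`'s signature while the model itself is what actually gets called.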


+ 1  - 18   fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -81,7 +81,7 @@ class TorchDriver(Driver):
         self.grad_scaler.update()

     @staticmethod
-    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
+    def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         if is_train:
             if not isinstance(dataloader, DataLoader):
                 raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.")
@@ -108,23 +108,6 @@ class TorchDriver(Driver):
             raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, "
                              f"not {type(each_optimizer)}.")

-    def check_evaluator_mode(self, mode: str):
-        model = self.unwrap_model()
-        if mode == "validate":
-            if not hasattr(model, "validate_step"):
-                if hasattr(model, "test_step"):
-                    logger.warning_once(
-                        "Your model does not have 'validate_step' method but has 'test_step' method, but you"
-                        "are using 'mode=validate', we are going to use 'test_step' to substitute for"
-                        "'validate_step'.")
-
-        else:
-            if not hasattr(model, "test_step"):
-                if hasattr(model, "validate_step"):
-                    logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you"
-                                   "are using 'mode=test', we are going to use 'validate_step' to substitute for"
-                                   "'test_step'.")
-
     @staticmethod
     def tensor_to_numeric(tensor, reduce=None):
         if tensor is None:
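
With the leading underscore dropped, `check_dataloader_legality` is part of the driver's public surface and can be called directly. A minimal sketch, assuming the import path shown in this file's header and the checks visible in the diff above:

import torch
from torch.utils.data import DataLoader, TensorDataset

# import path taken from the file header above; treat it as an assumption
from fastNLP.core.drivers.torch_driver.torch_driver import TorchDriver

dataset = TensorDataset(torch.rand(8, 4), torch.randint(0, 2, (8,)))

# a legal training dataloader passes silently
TorchDriver.check_dataloader_legality(DataLoader(dataset, batch_size=2),
                                      "train_dataloader", is_train=True)

# anything that is not a DataLoader raises ValueError
try:
    TorchDriver.check_dataloader_legality(dataset, "train_dataloader", is_train=True)
except ValueError as e:
    print(e)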


+ 6  - 51   fastNLP/core/drivers/torch_driver/utils.py

@@ -90,14 +90,11 @@ class ForwardState(IntEnum):
     PREDICT = 3


-_MODE_PARAMETER = "_forward_state"
-
-
 class _DDPWrappingModel(Module):
     """
     This class handles users' customized functions such as train_step during DDP training;
     The reason this extra wrapping model is needed is that under DDP, DistributedDataParallel's forward function must be used for things to run correctly;
-    on the other hand, we require users of our framework to implement different handler functions for different modes, e.g. 'train_step', 'validate_step', etc.;
+    on the other hand, we require users of our framework to implement different handler functions for different modes, e.g. 'train_step', 'evaluate_step', etc.;
     however, once the model is wrapped by DistributedDataParallel, no method other than forward is visible on it; and when we try to actively extract
     `model = model.module` during training, that likewise causes errors and makes the model parameters on each GPU differ;

@@ -109,60 +106,18 @@ class _DDPWrappingModel(Module):
         super(_DDPWrappingModel, self).__init__()
         self.model = model

-        if hasattr(model, "train_step"):
-            self._train_step = model.train_step
-            self._train_signature_fn = None
-        else:
-            self._train_step = model
-            self._train_signature_fn = model.forward
-
-        if hasattr(model, "validate_step"):
-            self._validate_step = model.validate_step
-            self._validate_signature_fn = None
-        elif hasattr(model, "test_step"):
-            self._validate_step = model.test_step
-            self._validate_signature_fn = None
-        else:
-            self._validate_step = model
-            self._validate_signature_fn = model.forward
-
-        if hasattr(model, "test_step"):
-            self._test_step = model.test_step
-            self._test_signature_fn = None
-        elif hasattr(model, "validate_step"):
-            self._test_step = model.validate_step
-            self._test_signature_fn = None
-        else:
-            self._test_step = model
-            self._test_signature_fn = model.forward
-
     def forward(self, batch, **kwargs) -> Dict:
         """
         pytorch lightning performs an unwrapping_model step first, but it does not seem necessary for us; leaving a note here and revisiting if the need arises;
         """
-        forward_state = kwargs.pop(_MODE_PARAMETER)
+        fn = kwargs.pop("fastnlp_fn")
+        signature_fn = kwargs.pop("fastnlp_signature_fn")
         wo_auto_param_call = kwargs.pop("wo_auto_param_call")

-        if forward_state == ForwardState.TRAIN:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
-            else:
-                return self._train_step(batch)
-        elif forward_state == ForwardState.VALIDATE:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
-            else:
-                return self._validate_step(batch)
-        elif forward_state == ForwardState.TEST:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
-            else:
-                return self._test_step(batch)
-        elif forward_state == ForwardState.PREDICT:
-            raise NotImplementedError("'PREDICT' mode has not been implemented.")
+        if isinstance(batch, Dict) and not wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
         else:
-            raise NotImplementedError("You should direct a concrete mode.")
+            return fn(batch)


 class DummyGradScaler:
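
The docstring's rationale can be made concrete with a toy, single-process re-creation of the wrapping idea: DDP only intercepts `forward()`, so instead of selecting a step through the removed `ForwardState` enum, the concrete step function now travels in through `forward`'s kwargs. Everything below is illustrative pure torch, not fastNLP code.

from typing import Callable, Dict

import torch
from torch import nn


class WrappingModel(nn.Module):
    """Toy re-creation of the _DDPWrappingModel idea: the step function is
    smuggled in through forward's kwargs, since DDP only intercepts forward()."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, batch, **kwargs) -> Dict:
        fn: Callable = kwargs.pop("fastnlp_fn")
        # the real implementation also pops "fastnlp_signature_fn" and
        # "wo_auto_param_call" and routes through auto_param_call; elided here
        return fn(batch)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, batch):
        return {"loss": self.forward(batch).sum()}


model = ToyModel()
wrapped = WrappingModel(model)
# Under DDP this call would be DistributedDataParallel(wrapped)(batch, fastnlp_fn=...);
# single-process here, but the routing through forward() is the same.
out = wrapped(torch.rand(3, 4), fastnlp_fn=model.train_step)
print(out["loss"].item())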


+ 5  - 5   fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py

@@ -55,8 +55,8 @@ class TorchPaddleDriver(Driver):
             self._train_step = self.model
             self._train_signature_fn = self.model.forward

-        if hasattr(self.model, "validate_step"):
-            self._validate_step = self.model.validate_step
+        if hasattr(self.model, "evaluate_step"):
+            self._validate_step = self.model.evaluate_step
             self._validate_signature_fn = None
         elif hasattr(self.model, "test_step"):
             self._validate_step = self.model.test_step
@@ -68,8 +68,8 @@ class TorchPaddleDriver(Driver):
         if hasattr(self.model, "test_step"):
             self._test_step = self.model.test_step
             self._test_signature_fn = None
-        elif hasattr(self.model, "validate_step"):
-            self._test_step = self.model.validate_step
+        elif hasattr(self.model, "evaluate_step"):
+            self._test_step = self.model.evaluate_step
             self._test_signature_fn = self.model.forward
         else:
             self._test_step = self.model
@@ -81,7 +81,7 @@ class TorchPaddleDriver(Driver):
         self.model.to(self.model_device)

     @staticmethod
-    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
+    def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         if is_train:
             if not isinstance(dataloader, (TorchDataLoader, PaddleDataLoader)):
                 raise ValueError(f"Parameter `{dataloader_name}` should be 'torch.util.data.DataLoader' or `paddle.io.dataloader` type, not {type(dataloader)}.")


+ 3  - 3   fastNLP/core/log/logger.py

@@ -211,9 +211,9 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]]
         raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.")

     if not isinstance(mode, str):
-        raise TypeError("Parameter 'mode' can only be `str` type.")
+        raise TypeError("Parameter 'evaluate_fn' can only be `str` type.")
     if mode not in {"w", "a"}:
-        raise ValueError("Parameter `mode` can only be one of these values: ('w', 'a').")
+        raise ValueError("Parameter `evaluate_fn` can only be one of these values: ('w', 'a').")

     for h in _logger.handlers:
         if isinstance(h, logging.FileHandler):
@@ -230,7 +230,7 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]]
     dirname = os.path.abspath(os.path.dirname(path))
     os.makedirs(dirname, exist_ok=True)

-    # As soon as distributed training is detected here, we change mode to "a"; one problem this causes is that if a second run is also distributed, the log recorded by the logger will not
+    # As soon as distributed training is detected here, we change evaluate_fn to "a"; one problem this causes is that if a second run is also distributed, the log recorded by the logger will not
     # overwrite the original file, but will keep appending after the previous log;
     # This is mainly to solve the problem caused by the following situation: in distributed training, process 1 runs to this point before process 0, causing process 0 to overwrite process 1's log;
     if is_cur_env_distributed():# and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0:


+ 8  - 8   tests/core/callbacks/test_checkpoint_callback_torch.py

@@ -124,7 +124,7 @@ def test_model_checkpoint_callback_1(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -204,7 +204,7 @@ def test_model_checkpoint_callback_1(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -264,7 +264,7 @@ def test_model_checkpoint_callback_2(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -302,7 +302,7 @@ def test_model_checkpoint_callback_2(
         device=4,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -370,7 +370,7 @@ def test_trainer_checkpoint_callback_1(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -448,7 +448,7 @@ def test_trainer_checkpoint_callback_1(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -626,7 +626,7 @@ def test_trainer_checkpoint_callback_2(
         train_dataloader=test_bert_dataloader_train,
         optimizers=test_bert_optimizers,

-        validate_dataloaders=test_bert_dataloader_validate,
+        evaluate_dataloaders=test_bert_dataloader_validate,
         input_mapping=bert_input_mapping,
         output_mapping=bert_output_mapping,
         metrics={"acc": acc},
@@ -700,7 +700,7 @@ def test_trainer_checkpoint_callback_2(
         train_dataloader=test_bert_dataloader_train,
         optimizers=test_bert_optimizers,

-        validate_dataloaders=test_bert_dataloader_validate,
+        evaluate_dataloaders=test_bert_dataloader_validate,
         input_mapping=bert_input_mapping,
         output_mapping=bert_output_mapping,
         metrics={"acc": acc},


+ 1  - 1   tests/core/callbacks/test_load_best_model_callback_torch.py

@@ -92,7 +92,7 @@ def test_load_best_model_callback(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']},
         metrics=model_and_optimizers.metrics,


+ 1  - 1   tests/core/controllers/_test_distributed_launch_torch_1.py

@@ -89,7 +89,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         device=None,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
-        validate_dataloaders=validate_dataloaders,
+        evaluate_dataloaders=validate_dataloaders,
         metrics=metrics,

         n_epochs=2,


+ 1  - 1   tests/core/controllers/_test_distributed_launch_torch_2.py

@@ -77,7 +77,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         device=None,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
-        validate_dataloaders=validate_dataloaders,
+        evaluate_dataloaders=validate_dataloaders,
         metrics=metrics,

         n_epochs=2,


+ 1  - 1   tests/core/controllers/test_trainer_event_trigger.py

@@ -82,7 +82,7 @@ def test_trainer_event_trigger(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,


+ 2  - 2   tests/core/controllers/test_trainer_fleet.py

@@ -64,8 +64,8 @@ def test_trainer_fleet(
         device=device,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
-        validate_dataloaders=validate_dataloaders,
-        validate_every=validate_every,
+        evaluate_dataloaders=validate_dataloaders,
+        evaluate_every=validate_every,
         input_mapping=None,
         output_mapping=None,
         metrics=metrics,


+ 2  - 2   tests/core/controllers/test_trainer_fleet_outside.py

@@ -70,8 +70,8 @@ def test_trainer_fleet(
         device=device,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
-        validate_dataloaders=validate_dataloaders,
-        validate_every=validate_every,
+        evaluate_dataloaders=validate_dataloaders,
+        evaluate_every=validate_every,
         input_mapping=None,
         output_mapping=None,
         metrics=metrics,


+ 7  - 7   tests/core/controllers/test_trainer_paddle.py

@@ -68,13 +68,13 @@ class TrainerParameters:
     #     shuffle=True
     # )
     # val_dataloader = DataLoader(
-    #     dataset=PaddleDataset_MNIST(mode="test"),
+    #     dataset=PaddleDataset_MNIST(evaluate_fn="test"),
     #     batch_size=MNISTTrainPaddleConfig.batch_size,
     #     shuffle=True
     # )
     # trainer_params.train_dataloader = train_dataloader
-    # trainer_params.validate_dataloaders = val_dataloader
-    # trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
+    # trainer_params.evaluate_dataloaders = val_dataloader
+    # trainer_params.evaluate_every = MNISTTrainPaddleConfig.evaluate_every
     # trainer_params.metrics = {"acc": Accuracy()}

     # return trainer_params
@@ -121,8 +121,8 @@ def test_trainer_paddle(
         device=device,
         optimizers=trainer_params.optimizers,
         train_dataloader=trainer_params.train_dataloader,
-        validate_dataloaders=trainer_params.validate_dataloaders,
-        validate_every=trainer_params.validate_every,
+        evaluate_dataloaders=trainer_params.validate_dataloaders,
+        evaluate_every=trainer_params.validate_every,
         input_mapping=trainer_params.input_mapping,
         output_mapping=trainer_params.output_mapping,
         metrics=trainer_params.metrics,
@@ -139,8 +139,8 @@ def test_trainer_paddle(
         device=device,
         optimizers=trainer_params.optimizers,
         train_dataloader=trainer_params.train_dataloader,
-        validate_dataloaders=trainer_params.validate_dataloaders,
-        validate_every=trainer_params.validate_every,
+        evaluate_dataloaders=trainer_params.validate_dataloaders,
+        evaluate_every=trainer_params.validate_every,
         input_mapping=trainer_params.input_mapping,
         output_mapping=trainer_params.output_mapping,
         metrics=trainer_params.metrics,


+ 8  - 8   tests/core/controllers/test_trainer_w_evaluator_torch.py

@@ -98,16 +98,16 @@ def model_and_optimizers(request):


 # Test the ordinary case;
-@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
 @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]])
-@pytest.mark.parametrize("validate_every", [-3])
+@pytest.mark.parametrize("evaluate_every", [-3, -1, 100])
 @magic_argv_env_context
 def test_trainer_torch_with_evaluator(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
         callbacks,
-        validate_every,
+        evaluate_every,
         n_epochs=10,
 ):
     trainer = Trainer(
@@ -116,11 +116,11 @@ def test_trainer_torch_with_evaluator(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
-        validate_every=validate_every,
+        evaluate_every=evaluate_every,

         n_epochs=n_epochs,
         callbacks=callbacks,
@@ -152,7 +152,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -193,14 +193,14 @@ def test_trainer_validate_every(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,

         n_epochs=n_epochs,
         output_from_new_proc="all",
-        validate_every=validate_every
+        evaluate_every=validate_every
     )

     trainer.run()
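
Taken together, these test changes show the renamed public surface: `evaluate_dataloaders` and `evaluate_every` replace `validate_dataloaders` and `validate_every`. A hedged end-to-end sketch with the new names follows; it assumes `Trainer` and `Accuracy` are importable from the top-level `fastNLP` package at this tag, and the model and data are placeholders rather than the fixtures used in these tests.

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from fastNLP import Trainer, Accuracy  # assumed top-level exports at this tag


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"x": torch.rand(10), "y": torch.randint(0, 2, (1,)).squeeze()}


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.fc(x)

    def train_step(self, x, y):
        return {"loss": self.loss_fn(self(x), y)}

    def evaluate_step(self, x, y):
        return {"pred": self(x), "target": y}


model = TinyModel()
trainer = Trainer(
    model=model,
    driver="torch",
    device="cpu",
    optimizers=torch.optim.SGD(model.parameters(), lr=0.1),
    train_dataloader=DataLoader(RandomDataset(), batch_size=8),
    evaluate_dataloaders=DataLoader(RandomDataset(), batch_size=8),  # was: validate_dataloaders
    evaluate_every=-1,                                               # was: validate_every
    metrics={"acc": Accuracy()},
    n_epochs=2,
)
trainer.run()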


+ 5  - 5   tests/core/controllers/test_trainer_wo_evaluator_torch.py

@@ -91,7 +91,7 @@ def test_trainer_torch_without_evaluator(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -126,7 +126,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -163,7 +163,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps(

         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -202,7 +202,7 @@ def test_trainer_output_from_new_proc(

         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -267,7 +267,7 @@ def test_trainer_on_exception(

         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,


+ 7  - 7   tests/core/drivers/paddle_driver/test_single_device.py

@@ -401,12 +401,12 @@ class TestPaddleDriverFunctions:
         Tests the behavior of the _check_dataloader_legality function when is_train is True
         """
         dataloader = paddle.io.DataLoader(PaddleNormalDataset())
-        PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)
+        PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

         # the case where both batch_size and batch_sampler are None
         dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
         with pytest.raises(ValueError):
-            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)
+            PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

         # create a torch dataloader
         dataloader = torch.utils.data.DataLoader(
@@ -414,7 +414,7 @@ class TestPaddleDriverFunctions:
             batch_size=32, shuffle=True
         )
         with pytest.raises(ValueError):
-            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)
+            PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

     def test_check_dataloader_legality_in_test(self):
         """
@@ -425,7 +425,7 @@ class TestPaddleDriverFunctions:
             "train": paddle.io.DataLoader(PaddleNormalDataset()),
             "test":paddle.io.DataLoader(PaddleNormalDataset())
         }
-        PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
+        PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

         # the case where both batch_size and batch_sampler are None
         dataloader = {
@@ -433,12 +433,12 @@ class TestPaddleDriverFunctions:
             "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
         }
         with pytest.raises(ValueError):
-            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
+            PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

         # what was passed in is not a dict, so an error should be raised
         dataloader = paddle.io.DataLoader(PaddleNormalDataset())
         with pytest.raises(ValueError):
-            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
+            PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

         # create torch dataloaders
         train_loader = torch.utils.data.DataLoader(
@@ -451,7 +451,7 @@ class TestPaddleDriverFunctions:
         )
         dataloader = {"train": train_loader, "test": test_loader}
         with pytest.raises(ValueError):
-            PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
+            PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

     def test_tensor_to_numeric(self):
         """


+ 1  - 1   tests/helpers/models/torch_model.py

@@ -28,7 +28,7 @@ class TorchNormalModel_Classification_1(nn.Module):
         x = self(x)
         return {"loss": self.loss_fn(x, y)}

-    def validate_step(self, x, y):
+    def evaluate_step(self, x, y):
         """
         If the parameter y is not added, output_mapping = {"y": "target"} should be set in the trainer;
         """

