@@ -95,7 +95,7 @@ class CheckpointCallback(HasMonitorCallback): | |||
if self.save_topk is not None: | |||
super().on_after_trainer_initialized(trainer, driver) | |||
if self.save_topk is not None and trainer.evaluator is None: | |||
logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.") | |||
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | |||
def on_validate_end(self, trainer, results): | |||
if len(results) == 0: | |||
@@ -39,7 +39,7 @@ class Evaluator: | |||
driver: Union[str, Driver] = 'single', | |||
device: Optional[Union[int, List[int], str]] = None, | |||
batch_step_fn: Optional[callable] = None, | |||
mode: Optional[Union[str, callable]] = 'validate', # 首先尝试找 evaluate_step, 找不到 forward, callable | |||
evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | |||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||
model_wo_auto_param_call: bool = False, | |||
@@ -58,14 +58,13 @@ class Evaluator: | |||
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | |||
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | |||
batch_step_fn 函数。 | |||
:param mode: 可选 ["validate", "test"], 当为 "validate" 时将首先尝试寻找 model 是否有 validate_step 函数,没有的话则尝试 | |||
寻找 test_step 函数,都没找到则使用 model 的前向运算函数。当为 "test" 是将首先尝试寻找 model 是否有 test_step 函数, | |||
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | |||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`; | |||
默认为 None,如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; | |||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | |||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | |||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; | |||
:param fp16: 是否使用 fp16 。 | |||
:param verbose: 是否打印 evaluate 的结果。 | |||
:param kwargs: | |||
@@ -87,9 +86,11 @@ class Evaluator: | |||
self.model = model | |||
self.metrics = metrics | |||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) | |||
if dataloaders is None: | |||
raise ValueError("Parameter `dataloaders` can not be None.") | |||
self.dataloaders = dataloaders | |||
self.device = device | |||
self.verbose = verbose | |||
@@ -97,21 +98,12 @@ class Evaluator: | |||
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | |||
self.batch_step_fn = batch_step_fn | |||
self.mode = mode | |||
assert mode in {'validate', 'test'}, "Parameter `mode` should only be 'validate' or 'test'." | |||
self.input_mapping = input_mapping | |||
self.output_mapping = output_mapping | |||
if not isinstance(dataloaders, dict): | |||
dataloaders = {None: dataloaders} | |||
if mode == "validate": | |||
self._evaluate_step = self.driver.validate_step | |||
self.driver.set_dataloader(validate_dataloaders=dataloaders) | |||
else: | |||
self._evaluate_step = self.driver.test_step | |||
self.driver.set_dataloader(test_dataloaders=dataloaders) | |||
self.mode = mode | |||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) | |||
self.separator = kwargs.get('separator', '#') | |||
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | |||
@@ -123,10 +115,14 @@ class Evaluator: | |||
self._metric_wrapper = None | |||
_ = self.metrics_wrapper # 触发检查 | |||
assert self.driver.has_validate_dataloaders() or self.driver.has_test_dataloaders() | |||
self.driver.setup() | |||
self.driver.barrier() | |||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | |||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | |||
self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | |||
self.evaluate_fn = evaluate_fn | |||
self.dataloaders = {} | |||
for name, dl in dataloaders.items(): # 替换为正确的 sampler | |||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False) | |||
@@ -136,9 +132,10 @@ class Evaluator: | |||
if self.progress_bar == 'auto': | |||
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | |||
self.driver.check_evaluator_mode(self.mode) | |||
self.driver.barrier() | |||
self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False) | |||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||
""" | |||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | |||
@@ -156,11 +153,6 @@ class Evaluator: | |||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | |||
assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | |||
if self.mode == 'validate': | |||
assert self.driver.has_validate_dataloaders() | |||
else: | |||
assert self.driver.has_test_dataloaders() | |||
metric_results = {} | |||
self.reset() | |||
evaluate_context = self.driver.get_evaluate_context() | |||
@@ -235,13 +227,6 @@ class Evaluator: | |||
f_rich_progress.destroy_task(self._rich_task_id) | |||
delattr(self, '_rich_task_id') | |||
@property | |||
def eval_dataloaders(self): | |||
if self.mode == "validate": | |||
return self.driver.validate_dataloaders | |||
else: | |||
return self.driver.test_dataloaders | |||
@property | |||
def evaluate_batch_loop(self): | |||
return self._evaluate_batch_loop | |||
@@ -296,13 +281,13 @@ class Evaluator: | |||
def evaluate_step(self, batch): | |||
""" | |||
将 batch 传递到model中进行处理,根据当前 mode 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 | |||
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 | |||
返回。 | |||
:param batch: | |||
:return: | |||
""" | |||
outputs = self._evaluate_step(batch) | |||
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) | |||
outputs = match_and_substitute_params(self.output_mapping, outputs) | |||
return outputs | |||
@@ -41,19 +41,20 @@ class Trainer(TrainerEventTrigger): | |||
optimizers, | |||
device: Optional[Union[int, List[int], str]] = "cpu", | |||
n_epochs: int = 20, | |||
validate_dataloaders=None, | |||
evaluate_dataloaders=None, | |||
batch_step_fn: Optional[Callable] = None, | |||
validate_batch_step_fn: Optional[Callable] = None, | |||
validate_mode: Union[str, callable] = 'validate', | |||
evaluate_batch_step_fn: Optional[Callable] = None, | |||
train_fn: Optional[str] = None, | |||
evaluate_fn: Optional[str] = None, | |||
callbacks: Union[List[Callback], Callback, None] = None, | |||
metrics: Optional[dict] = None, | |||
validate_every: Optional[Union[int, callable]] = -1, | |||
evaluate_every: Optional[Union[int, Callable]] = -1, | |||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||
model_wo_auto_param_call: bool = False, | |||
accumulation_steps: int = 1, | |||
fp16: bool = False, | |||
monitor: Union[str, callable] = None, | |||
monitor: Union[str, Callable] = None, | |||
larger_better: bool = True, | |||
marker: Optional[str] = None, | |||
**kwargs | |||
@@ -79,19 +80,19 @@ class Trainer(TrainerEventTrigger): | |||
4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; | |||
5. None: 为None则不对模型进行任何处理; | |||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20; | |||
:param validate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | |||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | |||
为 None; | |||
:param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和 | |||
`batch`;默认为 None; | |||
:param validate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 | |||
:param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 | |||
两个参数必须为 `evaluator` 和 `batch`;默认为 None; | |||
:param validate_mode: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,其值应当为以下之一:["validate", "test"]; | |||
默认为 "validate";当为 "validate" 时将首先尝试寻找 model 是否有 validate_step 函数,没有的话则尝试 | |||
寻找 test_step 函数,都没找到则使用 model 的前向运算函数。当为 "test" 是将首先尝试寻找 model 是否有 test_step 函数, | |||
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | |||
:param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用哪一个函数,例如是 `model.train_step` 还是 `model.forward`; | |||
默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; | |||
:param evaluate_fn: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,应当为 None 或者一个字符串;其使用方式和 train_fn 类似; | |||
注意该参数我们会直接传给 Trainer 中内置的 Evaluator(如果不为 None); | |||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | |||
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | |||
:param validate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||
为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||
返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 | |||
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | |||
@@ -105,10 +106,10 @@ class Trainer(TrainerEventTrigger): | |||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | |||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `validate_step` 和 `test_step`; | |||
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`; | |||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | |||
:param fp16: 是否开启混合精度训练;默认为 False; | |||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||
:param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param larger_better: monitor 的值是否是越大越好。 | |||
@@ -136,10 +137,15 @@ class Trainer(TrainerEventTrigger): | |||
else: | |||
self.driver_name = driver.__class__.__name__ | |||
self.device = device | |||
if train_dataloader is None: | |||
raise ValueError("Parameter `train_dataloader` can not be None.") | |||
self.train_dataloader = train_dataloader | |||
self.evaluate_dataloaders = evaluate_dataloaders | |||
self.optimizers = optimizers | |||
self.fp16 = fp16 | |||
self.input_mapping = input_mapping | |||
self.output_mapping = output_mapping | |||
self.evaluate_fn = evaluate_fn | |||
self.batch_step_fn = batch_step_fn | |||
if batch_step_fn is not None: | |||
@@ -168,13 +174,13 @@ class Trainer(TrainerEventTrigger): | |||
optimizers=optimizers, | |||
device=device, | |||
n_epochs=n_epochs, | |||
validate_dataloaders=validate_dataloaders, | |||
validate_dataloaders=evaluate_dataloaders, | |||
batch_step_fn=batch_step_fn, | |||
validate_batch_step_fn=validate_batch_step_fn, | |||
validate_mode=validate_mode, | |||
validate_batch_step_fn=evaluate_batch_step_fn, | |||
evaluate_fn=evaluate_fn, | |||
callbacks=callbacks, | |||
metrics=metrics, | |||
validate_every=validate_every, | |||
validate_every=evaluate_every, | |||
input_mapping=input_mapping, | |||
output_mapping=output_mapping, | |||
model_wo_auto_param_call=model_wo_auto_param_call, | |||
@@ -185,9 +191,6 @@ class Trainer(TrainerEventTrigger): | |||
) | |||
self.driver.set_optimizers(optimizers=optimizers) | |||
if train_dataloader is not None: | |||
self.driver.set_dataloader(train_dataloader=train_dataloader) | |||
# 初始化 callback manager; | |||
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) | |||
# 添加所有的函数式 callbacks; | |||
@@ -213,25 +216,25 @@ class Trainer(TrainerEventTrigger): | |||
_dist_sampler = None | |||
""" 设置内部的 Evaluator """ | |||
if metrics is None and validate_dataloaders is not None: | |||
if metrics is None and evaluate_dataloaders is not None: | |||
raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.") | |||
if metrics is not None and validate_dataloaders is None: | |||
if metrics is not None and evaluate_dataloaders is None: | |||
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") | |||
self.evaluator = None | |||
self.monitor = monitor | |||
self.larger_better = larger_better | |||
if metrics is not None and validate_dataloaders is not None: | |||
check_validate_every(validate_every) | |||
if metrics is not None and evaluate_dataloaders is not None: | |||
check_validate_every(evaluate_every) | |||
self.evaluator = Evaluator( | |||
model=model, | |||
dataloaders=validate_dataloaders, | |||
dataloaders=evaluate_dataloaders, | |||
metrics=metrics, | |||
driver=self.driver, | |||
device=device, | |||
batch_step_fn=validate_batch_step_fn, | |||
mode=validate_mode, | |||
batch_step_fn=evaluate_batch_step_fn, | |||
evaluate_fn=evaluate_fn, | |||
input_mapping=input_mapping, | |||
output_mapping=output_mapping, | |||
fp16=fp16, | |||
@@ -241,12 +244,16 @@ class Trainer(TrainerEventTrigger): | |||
) | |||
self.metrics = metrics | |||
self.validate_every = validate_every | |||
self.validate_every = evaluate_every | |||
assert self.driver.has_train_dataloader() | |||
self.driver.setup() | |||
self.driver.barrier() | |||
if train_fn is not None and not isinstance(train_fn, str): | |||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | |||
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) | |||
self.train_fn = train_fn | |||
self.dataloader = self.train_dataloader | |||
self.driver.set_deterministic_dataloader(self.dataloader) | |||
@@ -257,6 +264,7 @@ class Trainer(TrainerEventTrigger): | |||
self.on_after_trainer_initialized(self.driver) | |||
self.driver.barrier() | |||
self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) | |||
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | |||
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, | |||
@@ -273,6 +281,7 @@ class Trainer(TrainerEventTrigger): | |||
行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) | |||
:return: | |||
""" | |||
if catch_KeyboardInterrupt is None: | |||
catch_KeyboardInterrupt = not self.driver.is_distributed() | |||
else: | |||
@@ -343,7 +352,8 @@ class Trainer(TrainerEventTrigger): | |||
_validate_res: dict = validate_fn() | |||
trainer.on_validate_end(_validate_res) | |||
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||
if self.evaluator is not None: | |||
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||
def step_validate(self): | |||
""" | |||
@@ -489,11 +499,6 @@ class Trainer(TrainerEventTrigger): | |||
self.has_checked_train_batch_loop = True | |||
""" Trainer 需要的一些 property """ | |||
@property | |||
def train_dataloader(self): | |||
return self.driver.train_dataloader | |||
@property | |||
def driver(self): | |||
return self._driver | |||
@@ -684,7 +689,7 @@ class Trainer(TrainerEventTrigger): | |||
def train_step(self, batch): | |||
with self.driver.auto_cast(): | |||
outputs = self.driver.train_step(batch) | |||
outputs = self.driver.model_call(batch, self._train_step, self._train_step_signature_fn) | |||
outputs = match_and_substitute_params(self.output_mapping, outputs) | |||
return outputs | |||
@@ -814,6 +819,24 @@ class Trainer(TrainerEventTrigger): | |||
def data_device(self): | |||
return self.driver.data_device | |||
""" dataloader property """ | |||
@property | |||
def train_dataloader(self): | |||
return self._train_dataloader | |||
@train_dataloader.setter | |||
def train_dataloader(self, train_dataloader): | |||
self._train_dataloader = train_dataloader | |||
@property | |||
def evaluate_dataloaders(self): | |||
return self._evaluate_dataloaders | |||
@evaluate_dataloaders.setter | |||
def evaluate_dataloaders(self, evaluate_dataloaders): | |||
self._evaluate_dataloaders = evaluate_dataloaders | |||
@@ -128,6 +128,6 @@ class _TruncatedDataLoader: | |||
def check_validate_every(validate_every): | |||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | |||
if callable(validate_every): | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) |
@@ -1,7 +1,7 @@ | |||
import os | |||
import signal | |||
import sys | |||
from typing import Any, Sequence, List, Optional, Callable, Dict, Union | |||
from typing import Any, Sequence, List, Optional, Callable, Dict, Union, Tuple | |||
from abc import ABC, abstractmethod | |||
from datetime import datetime | |||
from pathlib import Path | |||
@@ -79,41 +79,44 @@ class Driver(ABC): | |||
""" | |||
@abstractmethod | |||
def train_step(self, batch): | |||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
""" | |||
通过调用模型自带的 `train_step` 或者 `forward` 方法来实现训练的前向过程; | |||
如果检测到用户模型实现了 train_step | |||
通过调用 `fn` 来实现训练时的前向传播过程; | |||
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 | |||
函数; | |||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||
:return: 返回由模型的 `train_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
:param fn: 由 Trainer 传入的用于网络前向传播一次的函数; | |||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `train_step` function.") | |||
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") | |||
def validate_step(self, batch): | |||
@abstractmethod | |||
def get_model_call_fn(self, fn: str) -> Tuple: | |||
""" | |||
通过调用模型自带的 `validate_step` 或者 `forward` 方法来实现模型评测的前向过程; | |||
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; | |||
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; | |||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||
:return: 返回由模型的 `validate_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `validate_step` function.") | |||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | |||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||
def test_step(self, batch): | |||
""" | |||
通过调用模型自带的 `test_step` 或者 `forward` 方法来实现模型评测的前向过程; | |||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||
函数,然后给出 warning; | |||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | |||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||
可能需要额外标记最初传入 driver 的模型是哪种形式的; | |||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||
:return: 返回由模型的 `test_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
:param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; | |||
:return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `test_step` function.") | |||
def check_evaluator_mode(self, mode: str): | |||
r""" | |||
因为我们在具体的 driver 的 validate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||
因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么 | |||
我们应当提醒用户这一行为; | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `check_evaluator_mode` function.") | |||
raise NotImplementedError("Each specific driver should implemented its own `get_model_call_fn` function.") | |||
@property | |||
def model(self): | |||
@@ -123,59 +126,8 @@ class Driver(ABC): | |||
def model(self, model): | |||
self._model = model | |||
@property | |||
def train_dataloader(self): | |||
return self._train_dataloader | |||
@train_dataloader.setter | |||
def train_dataloader(self, train_dataloader: Any): | |||
self._train_dataloader = train_dataloader | |||
@property | |||
def validate_dataloaders(self): | |||
return self._validate_dataloaders | |||
@validate_dataloaders.setter | |||
def validate_dataloaders(self, validate_dataloaders: Any): | |||
self._validate_dataloaders = validate_dataloaders | |||
@property | |||
def test_dataloaders(self): | |||
return self._test_dataloaders | |||
@test_dataloaders.setter | |||
def test_dataloaders(self, test_dataloaders: Any): | |||
self._test_dataloaders = test_dataloaders | |||
@property | |||
def predict_dataloaders(self): | |||
return self._predict_dataloaders | |||
@predict_dataloaders.setter | |||
def predict_dataloaders(self, predict_dataloaders: Any): | |||
self._predict_dataloaders = predict_dataloaders | |||
def set_dataloader(self, **kwargs): | |||
r""" | |||
设置训练或者检验过程中的数据;用于在 trainer 和 evaluator 中将数据 dataloader 挂载到每一个具体的 driver 上; | |||
:param kwargs: 输入的数据,应当使用 'keyword-only' 的参数进行设置; | |||
""" | |||
if "train_dataloader" in kwargs: | |||
self.train_dataloader = kwargs["train_dataloader"] | |||
self._check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) | |||
if "validate_dataloaders" in kwargs: | |||
self.validate_dataloaders = kwargs["validate_dataloaders"] | |||
self._check_dataloader_legality(self.validate_dataloaders, "validate_dataloaders", is_train=False) | |||
if "test_dataloaders" in kwargs: | |||
self.test_dataloaders = kwargs["test_dataloaders"] | |||
self._check_dataloader_legality(self.test_dataloaders, "test_dataloaders", is_train=False) | |||
if "predict_dataloaders" in kwargs: | |||
self.predict_dataloaders = kwargs["predict_dataloaders"] | |||
self._check_dataloader_legality(self.predict_dataloaders, "predict_dataloaders", is_train=False) | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
r""" | |||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 | |||
行为是不相同的; | |||
@@ -183,19 +135,7 @@ class Driver(ABC): | |||
:param dataloader: 需要检测的输入的 `dataloader`; | |||
:param dataloader_name: | |||
""" | |||
raise NotImplementedError("Each specific driver should implemented its own `_check_dataloader_legality` function.") | |||
def has_train_dataloader(self): | |||
return "_train_dataloader" in self.__dict__ | |||
def has_validate_dataloaders(self): | |||
return "_validate_dataloaders" in self.__dict__ | |||
def has_test_dataloaders(self): | |||
return "_test_dataloaders" in self.__dict__ | |||
def has_predict_dataloaders(self): | |||
return "_predict_dataloaders" in self.__dict__ | |||
raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.") | |||
@property | |||
def optimizers(self) -> List: | |||
@@ -39,7 +39,7 @@ class JittorDriver(Driver): | |||
self.grad_scaler = _grad_scaler() | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
# 在fastnlp中实现了JittorDataLoader | |||
# TODO: 是否允许传入Dataset? | |||
if is_train: | |||
@@ -64,18 +64,18 @@ class JittorDriver(Driver): | |||
def check_evaluator_mode(self, mode: str): | |||
model = self.unwrap_model() | |||
if mode == "validate": | |||
if not hasattr(model, "validate_step"): | |||
if not hasattr(model, "evaluate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning_once( | |||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||
"'validate_step'.") | |||
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you" | |||
"are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for" | |||
"'evaluate_step'.") | |||
else: | |||
if not hasattr(model, "test_step"): | |||
if hasattr(model, "validate_step"): | |||
if hasattr(model, "evaluate_step"): | |||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
"are using 'mode=test', we are going to use 'validate_step' to substitute for" | |||
"are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for" | |||
"'test_step'.") | |||
def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): | |||
@@ -35,8 +35,8 @@ class JittorSingleDriver(JittorDriver): | |||
model = self.unwrap_model() | |||
self._train_signature_fn = model.execute | |||
if hasattr(self.model, "validate_step"): | |||
self._validate_step = self.model.validate_step | |||
if hasattr(self.model, "evaluate_step"): | |||
self._validate_step = self.model.evaluate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(self.model, "test_step"): | |||
self._validate_step = self.model.test_step | |||
@@ -49,9 +49,9 @@ class JittorSingleDriver(JittorDriver): | |||
if hasattr(self.model, "test_step"): | |||
self._test_step = self.model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(self.model, "validate_step"): | |||
self._test_step = self.model.validate_step | |||
self._test_signature_fn = self.model.validate_step | |||
elif hasattr(self.model, "evaluate_step"): | |||
self._test_step = self.model.evaluate_step | |||
self._test_signature_fn = self.model.evaluate_step | |||
else: | |||
self._test_step = self.model | |||
model = self.unwrap_model() | |||
@@ -118,11 +118,11 @@ class PaddleFleetDriver(PaddleDriver): | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
if hasattr(model, "validate_step"): | |||
if hasattr(model, "evaluate_step"): | |||
logger.warning( | |||
"Notice your model is a `paddle.DataParallel` model. And your " | |||
"model also implements the `validate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||
"model also implements the `evaluate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `evaluate_step` and you should note that.") | |||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
if hasattr(model, "test_step"): | |||
@@ -72,7 +72,7 @@ class PaddleDriver(Driver): | |||
optimizer.clear_grad() | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
r""" | |||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性。 | |||
要求传入的 dataloader 必须为 `paddle.io.DataLoader` 或包含该类型的字典。 | |||
@@ -117,24 +117,24 @@ class PaddleDriver(Driver): | |||
def check_evaluator_mode(self, mode: str): | |||
r""" | |||
因为我们在具体的 driver 的 validate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||
因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么 | |||
因为我们在具体的 driver 的 evaluate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||
因此如果用户的 evaluator evaluate_fn 是 validate,但是传入的 model 却没有实现 evaluate_step 函数,而是实现了 test_step 函数,那么 | |||
我们应当提醒用户这一行为; | |||
""" | |||
model = self.unwrap_model() | |||
if mode == "validate": | |||
if not hasattr(model, "validate_step"): | |||
if not hasattr(model, "evaluate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning( | |||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you" | |||
"are using 'Evaluator.validate', we are going to use 'test_step' to substitute for" | |||
"'validate_step'.") | |||
"'evaluate_step'.") | |||
else: | |||
if not hasattr(model, "test_step"): | |||
if hasattr(model, "validate_step"): | |||
if hasattr(model, "evaluate_step"): | |||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | |||
"are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for" | |||
"'test_step'.") | |||
@staticmethod | |||
@@ -50,10 +50,10 @@ class PaddleSingleDriver(PaddleDriver): | |||
self._train_step = self.model | |||
self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
if hasattr(model, "evaluate_step"): | |||
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | |||
"implements the `validate_step` method, which we can not call actually, we " | |||
"will call `forward` function instead of `validate_step` and you should note that.") | |||
"implements the `evaluate_step` method, which we can not call actually, we " | |||
"will call `forward` function instead of `evaluate_step` and you should note that.") | |||
self._validate_step = self.model | |||
self._validate_signature_fn = model.forward | |||
@@ -73,8 +73,8 @@ class PaddleSingleDriver(PaddleDriver): | |||
model = self.unwrap_model() | |||
self._train_signature_fn = model.forward | |||
if hasattr(self.model, "validate_step"): | |||
self._validate_step = self.model.validate_step | |||
if hasattr(self.model, "evaluate_step"): | |||
self._validate_step = self.model.evaluate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(self.model, "test_step"): | |||
self._validate_step = self.model.test_step | |||
@@ -87,9 +87,9 @@ class PaddleSingleDriver(PaddleDriver): | |||
if hasattr(self.model, "test_step"): | |||
self._test_step = self.model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(self.model, "validate_step"): | |||
self._test_step = self.model.validate_step | |||
self._test_signature_fn = self.model.validate_step | |||
elif hasattr(self.model, "evaluate_step"): | |||
self._test_step = self.model.evaluate_step | |||
self._test_signature_fn = self.model.evaluate_step | |||
else: | |||
self._test_step = self.model | |||
model = self.unwrap_model() | |||
@@ -108,11 +108,11 @@ class _FleetWrappingModel(Layer): | |||
self._train_step = self.model | |||
self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
if hasattr(model, "evaluate_step"): | |||
logger.warning( | |||
"Notice your model is a `paddle.DataParallel` model. And your " | |||
"model also implements the `validate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||
"model also implements the `evaluate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `evaluate_step` and you should note that.") | |||
self._validate_step = self.model | |||
self._validate_signature_fn = model.forward | |||
@@ -131,7 +131,7 @@ class _FleetWrappingModel(Layer): | |||
self._train_step = model | |||
self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
if hasattr(model, "evaluate_step"): | |||
self._validate_step = model.validate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(model, "test_step"): | |||
@@ -144,7 +144,7 @@ class _FleetWrappingModel(Layer): | |||
if hasattr(model, "test_step"): | |||
self._test_step = model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(model, "validate_step"): | |||
elif hasattr(model, "evaluate_step"): | |||
self._test_step = model.validate_step | |||
self._test_signature_fn = None | |||
else: | |||
@@ -172,9 +172,9 @@ class _FleetWrappingModel(Layer): | |||
else: | |||
return self._test_step(batch) | |||
elif forward_state == ForwardState.PREDICT: | |||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||
raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.") | |||
else: | |||
raise NotImplementedError("You should direct a concrete mode.") | |||
raise NotImplementedError("You should direct a concrete evaluate_fn.") | |||
class DummyGradScaler: | |||
""" | |||
@@ -4,7 +4,7 @@ import __main__ | |||
import socket | |||
import numpy as np | |||
from time import sleep | |||
from typing import List, Optional, Union, Dict | |||
from typing import List, Optional, Union, Dict, Tuple, Callable | |||
from functools import partial | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
@@ -21,8 +21,6 @@ __all__ = [ | |||
from .torch_driver import TorchDriver | |||
from fastNLP.core.drivers.torch_driver.utils import ( | |||
_DDPWrappingModel, | |||
ForwardState, | |||
_MODE_PARAMETER, | |||
reset_seed, | |||
replace_sampler, | |||
replace_batch_sampler | |||
@@ -158,10 +156,10 @@ class TorchDDPDriver(TorchDriver): | |||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||
3. _DDPWrappingModel 的作用; | |||
因为我们即需要调用模型的 `train_step`、`validate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 | |||
因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 | |||
forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel` | |||
的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 | |||
forward 函数,还是 `train_step`、`validate_step`、`test_step` 方法。 | |||
forward 函数,还是 `train_step`、`evaluate_step`、`test_step` 方法。 | |||
4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理; | |||
@@ -204,37 +202,6 @@ class TorchDDPDriver(TorchDriver): | |||
# 我们就直接将 model_device 置为 None; | |||
self.model_device = None | |||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | |||
else: | |||
return step_fn(batch) | |||
model = model.module | |||
if hasattr(model, "train_step"): | |||
logger.warning( | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `train_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
logger.warning( | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `validate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
logger.warning( | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `test_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `test_step` and you should note that.") | |||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._test_signature_fn = model.forward | |||
# 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; | |||
self._data_device = kwargs.get("data_device", None) | |||
if isinstance(self._data_device, int): | |||
@@ -253,7 +220,6 @@ class TorchDDPDriver(TorchDriver): | |||
# world_size 表示的就是全局的显卡的数量; | |||
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) | |||
self.global_rank = 0 | |||
self._configured = False # 防止重复调用 configure_ddp() 函数使用的 | |||
self._ddp_kwargs = kwargs.get("torch_ddp_kwargs", {}) | |||
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__) | |||
@@ -268,8 +234,8 @@ class TorchDDPDriver(TorchDriver): | |||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | |||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | |||
# 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||
self._has_setup = False | |||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||
self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | |||
def setup(self): | |||
if self._has_setup: | |||
@@ -341,24 +307,16 @@ class TorchDDPDriver(TorchDriver): | |||
self._pids = self.tensor_to_numeric(self._pids) | |||
def configure_ddp(self): | |||
if not self._configured and not isinstance(self.model, DistributedDataParallel): | |||
if not isinstance(self.model, DistributedDataParallel): | |||
self.model = DistributedDataParallel( | |||
# 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index; | |||
_DDPWrappingModel(self.model), device_ids=[self.model_device.index], | |||
**self._ddp_kwargs | |||
) | |||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._configured = True | |||
self._has_ddpwrapped = True | |||
def open_subprocess(self): | |||
if self.local_rank == 0: | |||
# self._consensus_file = Path(tempfile.mkstemp()[1]) | |||
# self._consensus_file.unlink() | |||
# Script called as `python a/b/c.py` | |||
if __main__.__spec__ is None: # pragma: no-cover | |||
# pull out the commands used to run the script and resolve the abs file path | |||
@@ -432,18 +390,39 @@ class TorchDDPDriver(TorchDriver): | |||
return self._data_device | |||
return self.model_device | |||
def train_step(self, batch): | |||
# 注意这里的 self.model 已经是 'fastNLP.drivers.utils._DDPWrappingModel'; | |||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||
return self._train_step(batch) | |||
def validate_step(self, batch): | |||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||
return self._validate_step(batch) | |||
def test_step(self, batch): | |||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
return self._test_step(batch) | |||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
if self._has_ddpwrapped: | |||
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, | |||
wo_auto_param_call=self.wo_auto_param_call) | |||
else: | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||
else: | |||
return fn(batch) | |||
def get_model_call_fn(self, fn: str) -> Tuple: | |||
model = self.unwrap_model() | |||
if self._has_ddpwrapped: | |||
if hasattr(model, fn): | |||
fn = getattr(model, fn) | |||
if not callable(fn): | |||
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||
return fn, None | |||
elif fn in {"train_step", "evaluate_step"}: | |||
return model, model.forward | |||
else: | |||
raise RuntimeError(f"There is no `{fn}` method in your model.") | |||
else: | |||
if hasattr(model, fn): | |||
logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements " | |||
f"the `{fn}` method, which we can not call actually, we will" | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
elif fn not in {"train_step", "evaluate_step"}: | |||
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a " | |||
"`DistributedDataParallel` model, which means that we will only call model.forward " | |||
"function when we are in forward propagation.") | |||
return self.model, model.forward | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | |||
reproducible: bool = False): | |||
@@ -1,5 +1,5 @@ | |||
import os | |||
from typing import Dict, Union | |||
from typing import Dict, Union, Callable, Tuple, Optional | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
@@ -42,84 +42,40 @@ class TorchSingleDriver(TorchDriver): | |||
self.global_rank = 0 | |||
self.world_size = 1 | |||
if isinstance(model, DataParallel): | |||
model = self.unwrap_model() | |||
if hasattr(model, "train_step"): | |||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||
"model also implements the `train_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
self._train_step = self.model | |||
self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||
"model also implements the `validate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||
self._validate_step = self.model | |||
self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||
"model also implements the `test_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `test_step` and you should note that.") | |||
self._test_step = self.model | |||
self._test_signature_fn = model.forward | |||
else: | |||
if hasattr(self.model, "train_step"): | |||
self._train_step = self.model.train_step | |||
self._train_signature_fn = None | |||
else: | |||
self._train_step = self.model | |||
# 输入的模型是 `DataParallel` 或者 `DistributedDataParallel`,我们需要保证其 signature_fn 是正确的; | |||
model = self.unwrap_model() | |||
self._train_signature_fn = model.forward | |||
if hasattr(self.model, "validate_step"): | |||
self._validate_step = self.model.validate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(self.model, "test_step"): | |||
self._validate_step = self.model.test_step | |||
self._validate_signature_fn = self.model.test_step | |||
else: | |||
self._validate_step = self.model | |||
model = self.unwrap_model() | |||
self._validate_signature_fn = model.forward | |||
if hasattr(self.model, "test_step"): | |||
self._test_step = self.model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(self.model, "validate_step"): | |||
self._test_step = self.model.validate_step | |||
self._test_signature_fn = self.model.validate_step | |||
else: | |||
self._test_step = self.model | |||
model = self.unwrap_model() | |||
self._test_signature_fn = model.forward | |||
def setup(self): | |||
if self.model_device is not None: | |||
self.model.to(self.model_device) | |||
def train_step(self, batch) -> Dict: | |||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | |||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
return fn(batch) | |||
def validate_step(self, batch) -> Dict: | |||
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | |||
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
def get_model_call_fn(self, fn: str) -> Tuple: | |||
if isinstance(self.model, DataParallel): | |||
model = self.unwrap_model() | |||
if hasattr(model, fn): | |||
logger.warning("Notice your model is a `DataParallel` model. And your model also implements the " | |||
f"`{fn}` method, which we can not call actually, we will" | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
def test_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
elif fn not in {"train_step", "evaluate_step"}: | |||
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a " | |||
f"`DataParallel` model, which means that we will only call model.forward function " | |||
f"when we are in forward propagation.") | |||
return self.model, model.forward | |||
else: | |||
return self._test_step(batch) | |||
if hasattr(self.model, fn): | |||
fn = getattr(self.model, fn) | |||
if not callable(fn): | |||
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||
return fn, None | |||
elif fn in {"train_step", "evaluate_step"}: | |||
return self.model, self.model.forward | |||
else: | |||
raise RuntimeError(f"There is no `{fn}` method in your model.") | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||
reproducible: bool = False): | |||
@@ -81,7 +81,7 @@ class TorchDriver(Driver): | |||
self.grad_scaler.update() | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
if is_train: | |||
if not isinstance(dataloader, DataLoader): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.") | |||
@@ -108,23 +108,6 @@ class TorchDriver(Driver): | |||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||
f"not {type(each_optimizer)}.") | |||
def check_evaluator_mode(self, mode: str): | |||
model = self.unwrap_model() | |||
if mode == "validate": | |||
if not hasattr(model, "validate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning_once( | |||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||
"'validate_step'.") | |||
else: | |||
if not hasattr(model, "test_step"): | |||
if hasattr(model, "validate_step"): | |||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
"are using 'mode=test', we are going to use 'validate_step' to substitute for" | |||
"'test_step'.") | |||
@staticmethod | |||
def tensor_to_numeric(tensor, reduce=None): | |||
if tensor is None: | |||
@@ -90,14 +90,11 @@ class ForwardState(IntEnum): | |||
PREDICT = 3 | |||
_MODE_PARAMETER = "_forward_state" | |||
class _DDPWrappingModel(Module): | |||
""" | |||
该函数用于 DDP 训练时处理用户自己定制的 train_step 等函数; | |||
之所以要使用这一额外的包裹模型,是因为在使用 DDP 时,必须使用 DistributedDataParallel 的 forward 函数才能实现正常的运行; | |||
另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'validate_step' 等; | |||
另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'evaluate_step' 等; | |||
然而,当使用 DistributedDataParallel 包裹 model 后,模型看不见其除了 forward 之外的方法;并且当我们尝试在训练过程中主动提取 | |||
`model = model.module`,这同样会导致错误,会使得每一个gpu上的模型参数不同; | |||
@@ -109,60 +106,18 @@ class _DDPWrappingModel(Module): | |||
super(_DDPWrappingModel, self).__init__() | |||
self.model = model | |||
if hasattr(model, "train_step"): | |||
self._train_step = model.train_step | |||
self._train_signature_fn = None | |||
else: | |||
self._train_step = model | |||
self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
self._validate_step = model.validate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(model, "test_step"): | |||
self._validate_step = model.test_step | |||
self._validate_signature_fn = None | |||
else: | |||
self._validate_step = model | |||
self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
self._test_step = model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(model, "validate_step"): | |||
self._test_step = model.validate_step | |||
self._test_signature_fn = None | |||
else: | |||
self._test_step = model | |||
self._test_signature_fn = model.forward | |||
def forward(self, batch, **kwargs) -> Dict: | |||
""" | |||
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | |||
""" | |||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||
fn = kwargs.pop("fastnlp_fn") | |||
signature_fn = kwargs.pop("fastnlp_signature_fn") | |||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||
if forward_state == ForwardState.TRAIN: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
elif forward_state == ForwardState.VALIDATE: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
elif forward_state == ForwardState.TEST: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
elif forward_state == ForwardState.PREDICT: | |||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||
else: | |||
raise NotImplementedError("You should direct a concrete mode.") | |||
return fn(batch) | |||
class DummyGradScaler: | |||
@@ -55,8 +55,8 @@ class TorchPaddleDriver(Driver): | |||
self._train_step = self.model | |||
self._train_signature_fn = self.model.forward | |||
if hasattr(self.model, "validate_step"): | |||
self._validate_step = self.model.validate_step | |||
if hasattr(self.model, "evaluate_step"): | |||
self._validate_step = self.model.evaluate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(self.model, "test_step"): | |||
self._validate_step = self.model.test_step | |||
@@ -68,8 +68,8 @@ class TorchPaddleDriver(Driver): | |||
if hasattr(self.model, "test_step"): | |||
self._test_step = self.model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(self.model, "validate_step"): | |||
self._test_step = self.model.validate_step | |||
elif hasattr(self.model, "evaluate_step"): | |||
self._test_step = self.model.evaluate_step | |||
self._test_signature_fn = self.model.forward | |||
else: | |||
self._test_step = self.model | |||
@@ -81,7 +81,7 @@ class TorchPaddleDriver(Driver): | |||
self.model.to(self.model_device) | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
if is_train: | |||
if not isinstance(dataloader, (TorchDataLoader, PaddleDataLoader)): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'torch.util.data.DataLoader' or `paddle.io.dataloader` type, not {type(dataloader)}.") | |||
@@ -211,9 +211,9 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] | |||
raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.") | |||
if not isinstance(mode, str): | |||
raise TypeError("Parameter 'mode' can only be `str` type.") | |||
raise TypeError("Parameter 'evaluate_fn' can only be `str` type.") | |||
if mode not in {"w", "a"}: | |||
raise ValueError("Parameter `mode` can only be one of these values: ('w', 'a').") | |||
raise ValueError("Parameter `evaluate_fn` can only be one of these values: ('w', 'a').") | |||
for h in _logger.handlers: | |||
if isinstance(h, logging.FileHandler): | |||
@@ -230,7 +230,7 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] | |||
dirname = os.path.abspath(os.path.dirname(path)) | |||
os.makedirs(dirname, exist_ok=True) | |||
# 这里只要检测到是分布式训练,我们就将 mode 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 | |||
# 这里只要检测到是分布式训练,我们就将 evaluate_fn 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 | |||
# 覆盖掉原文件,而是会接着上一次的 log 继续添加; | |||
# 这样做主要是为了解决这样的情形所导致的问题:在分布式训练中,进程 1 比 进程 0 先运行到这里,然后使得进程 0 将进程 1 的 log 覆盖掉; | |||
if is_cur_env_distributed():# and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0: | |||
@@ -124,7 +124,7 @@ def test_model_checkpoint_callback_1( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -204,7 +204,7 @@ def test_model_checkpoint_callback_1( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -264,7 +264,7 @@ def test_model_checkpoint_callback_2( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -302,7 +302,7 @@ def test_model_checkpoint_callback_2( | |||
device=4, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -370,7 +370,7 @@ def test_trainer_checkpoint_callback_1( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -448,7 +448,7 @@ def test_trainer_checkpoint_callback_1( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -626,7 +626,7 @@ def test_trainer_checkpoint_callback_2( | |||
train_dataloader=test_bert_dataloader_train, | |||
optimizers=test_bert_optimizers, | |||
validate_dataloaders=test_bert_dataloader_validate, | |||
evaluate_dataloaders=test_bert_dataloader_validate, | |||
input_mapping=bert_input_mapping, | |||
output_mapping=bert_output_mapping, | |||
metrics={"acc": acc}, | |||
@@ -700,7 +700,7 @@ def test_trainer_checkpoint_callback_2( | |||
train_dataloader=test_bert_dataloader_train, | |||
optimizers=test_bert_optimizers, | |||
validate_dataloaders=test_bert_dataloader_validate, | |||
evaluate_dataloaders=test_bert_dataloader_validate, | |||
input_mapping=bert_input_mapping, | |||
output_mapping=bert_output_mapping, | |||
metrics={"acc": acc}, | |||
@@ -92,7 +92,7 @@ def test_load_best_model_callback( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -89,7 +89,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||
device=None, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
validate_dataloaders=validate_dataloaders, | |||
evaluate_dataloaders=validate_dataloaders, | |||
metrics=metrics, | |||
n_epochs=2, | |||
@@ -77,7 +77,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||
device=None, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
validate_dataloaders=validate_dataloaders, | |||
evaluate_dataloaders=validate_dataloaders, | |||
metrics=metrics, | |||
n_epochs=2, | |||
@@ -82,7 +82,7 @@ def test_trainer_event_trigger( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -64,8 +64,8 @@ def test_trainer_fleet( | |||
device=device, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
validate_dataloaders=validate_dataloaders, | |||
validate_every=validate_every, | |||
evaluate_dataloaders=validate_dataloaders, | |||
evaluate_every=validate_every, | |||
input_mapping=None, | |||
output_mapping=None, | |||
metrics=metrics, | |||
@@ -70,8 +70,8 @@ def test_trainer_fleet( | |||
device=device, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
validate_dataloaders=validate_dataloaders, | |||
validate_every=validate_every, | |||
evaluate_dataloaders=validate_dataloaders, | |||
evaluate_every=validate_every, | |||
input_mapping=None, | |||
output_mapping=None, | |||
metrics=metrics, | |||
@@ -68,13 +68,13 @@ class TrainerParameters: | |||
# shuffle=True | |||
# ) | |||
# val_dataloader = DataLoader( | |||
# dataset=PaddleDataset_MNIST(mode="test"), | |||
# dataset=PaddleDataset_MNIST(evaluate_fn="test"), | |||
# batch_size=MNISTTrainPaddleConfig.batch_size, | |||
# shuffle=True | |||
# ) | |||
# trainer_params.train_dataloader = train_dataloader | |||
# trainer_params.validate_dataloaders = val_dataloader | |||
# trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | |||
# trainer_params.evaluate_dataloaders = val_dataloader | |||
# trainer_params.evaluate_every = MNISTTrainPaddleConfig.evaluate_every | |||
# trainer_params.metrics = {"acc": Accuracy()} | |||
# return trainer_params | |||
@@ -121,8 +121,8 @@ def test_trainer_paddle( | |||
device=device, | |||
optimizers=trainer_params.optimizers, | |||
train_dataloader=trainer_params.train_dataloader, | |||
validate_dataloaders=trainer_params.validate_dataloaders, | |||
validate_every=trainer_params.validate_every, | |||
evaluate_dataloaders=trainer_params.validate_dataloaders, | |||
evaluate_every=trainer_params.validate_every, | |||
input_mapping=trainer_params.input_mapping, | |||
output_mapping=trainer_params.output_mapping, | |||
metrics=trainer_params.metrics, | |||
@@ -139,8 +139,8 @@ def test_trainer_paddle( | |||
device=device, | |||
optimizers=trainer_params.optimizers, | |||
train_dataloader=trainer_params.train_dataloader, | |||
validate_dataloaders=trainer_params.validate_dataloaders, | |||
validate_every=trainer_params.validate_every, | |||
evaluate_dataloaders=trainer_params.validate_dataloaders, | |||
evaluate_every=trainer_params.validate_every, | |||
input_mapping=trainer_params.input_mapping, | |||
output_mapping=trainer_params.output_mapping, | |||
metrics=trainer_params.metrics, | |||
@@ -98,16 +98,16 @@ def model_and_optimizers(request): | |||
# 测试一下普通的情况; | |||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | |||
@pytest.mark.parametrize("validate_every", [-3]) | |||
@pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) | |||
@magic_argv_env_context | |||
def test_trainer_torch_with_evaluator( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device, | |||
callbacks, | |||
validate_every, | |||
evaluate_every, | |||
n_epochs=10, | |||
): | |||
trainer = Trainer( | |||
@@ -116,11 +116,11 @@ def test_trainer_torch_with_evaluator( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
validate_every=validate_every, | |||
evaluate_every=evaluate_every, | |||
n_epochs=n_epochs, | |||
callbacks=callbacks, | |||
@@ -152,7 +152,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -193,14 +193,14 @@ def test_trainer_validate_every( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
n_epochs=n_epochs, | |||
output_from_new_proc="all", | |||
validate_every=validate_every | |||
evaluate_every=validate_every | |||
) | |||
trainer.run() | |||
@@ -91,7 +91,7 @@ def test_trainer_torch_without_evaluator( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -126,7 +126,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -163,7 +163,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -202,7 +202,7 @@ def test_trainer_output_from_new_proc( | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -267,7 +267,7 @@ def test_trainer_on_exception( | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
@@ -401,12 +401,12 @@ class TestPaddleDriverFunctions: | |||
测试is_train参数为True时,_check_dataloader_legality函数的表现 | |||
""" | |||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||
# batch_size 和 batch_sampler 均为 None 的情形 | |||
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||
# 创建torch的dataloader | |||
dataloader = torch.utils.data.DataLoader( | |||
@@ -414,7 +414,7 @@ class TestPaddleDriverFunctions: | |||
batch_size=32, shuffle=True | |||
) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||
def test_check_dataloader_legality_in_test(self): | |||
""" | |||
@@ -425,7 +425,7 @@ class TestPaddleDriverFunctions: | |||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||
"test":paddle.io.DataLoader(PaddleNormalDataset()) | |||
} | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
# batch_size 和 batch_sampler 均为 None 的情形 | |||
dataloader = { | |||
@@ -433,12 +433,12 @@ class TestPaddleDriverFunctions: | |||
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||
} | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
# 传入的不是dict,应该报错 | |||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
# 创建torch的dataloader | |||
train_loader = torch.utils.data.DataLoader( | |||
@@ -451,7 +451,7 @@ class TestPaddleDriverFunctions: | |||
) | |||
dataloader = {"train": train_loader, "test": test_loader} | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
def test_tensor_to_numeric(self): | |||
""" | |||
@@ -28,7 +28,7 @@ class TorchNormalModel_Classification_1(nn.Module): | |||
x = self(x) | |||
return {"loss": self.loss_fn(x, y)} | |||
def validate_step(self, x, y): | |||
def evaluate_step(self, x, y): | |||
""" | |||
如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"}; | |||
""" | |||