@@ -49,8 +49,8 @@ class Evaluator: | |||||
): | ): | ||||
""" | """ | ||||
:param model: | |||||
:param dataloaders: | |||||
:param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 | |||||
:param dataloaders: 待评测的数据集。 | |||||
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | ||||
metric ,torchmetrics,allennlpmetrics等。 | metric ,torchmetrics,allennlpmetrics等。 | ||||
:param driver: 使用 driver 。 | :param driver: 使用 driver 。 | ||||
@@ -119,7 +119,7 @@ class Evaluator: | |||||
self.driver.barrier() | self.driver.barrier() | ||||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | if evaluate_fn is not None and not isinstance(evaluate_fn, str): | ||||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | |||||
raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | |||||
self._evaluate_step, self._evaluate_step_signature_fn = \ | self._evaluate_step, self._evaluate_step_signature_fn = \ | ||||
self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | ||||
self.evaluate_fn = evaluate_fn | self.evaluate_fn = evaluate_fn | ||||
@@ -86,10 +86,12 @@ class Trainer(TrainerEventTrigger): | |||||
`batch`;默认为 None; | `batch`;默认为 None; | ||||
:param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 | :param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 | ||||
两个参数必须为 `evaluator` 和 `batch`;默认为 None; | 两个参数必须为 `evaluator` 和 `batch`;默认为 None; | ||||
:param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用哪一个函数,例如是 `model.train_step` 还是 `model.forward`; | |||||
默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; | |||||
:param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 `train_step` 还是 `forward`; | |||||
默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法, | |||||
则使用模型默认的前向传播函数。 | |||||
:param evaluate_fn: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,应当为 None 或者一个字符串;其使用方式和 train_fn 类似; | :param evaluate_fn: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,应当为 None 或者一个字符串;其使用方式和 train_fn 类似; | ||||
注意该参数我们会直接传给 Trainer 中内置的 Evaluator(如果不为 None); | |||||
注意该参数我们会直接传给 Trainer 中内置的 Evaluator(如果不为 None);如果该值为 None ,将首先尝试寻找模型中是否有 | |||||
evaluate_step 这个函数,如果没有则使用 forward 函数。 | |||||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | ||||
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | ||||
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | :param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | ||||
@@ -5,6 +5,7 @@ if _NEED_IMPORT_TORCH: | |||||
import torch | import torch | ||||
from torch.nn import DataParallel | from torch.nn import DataParallel | ||||
from torch.nn.parallel import DistributedDataParallel | from torch.nn.parallel import DistributedDataParallel | ||||
from torch.utils.data import RandomSampler as TorchRandomSampler | |||||
__all__ = [ | __all__ = [ | ||||
'TorchSingleDriver' | 'TorchSingleDriver' | ||||
@@ -13,7 +14,9 @@ __all__ = [ | |||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.utils.utils import _get_fun_msg | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | ||||
from fastNLP.core.samplers import RandomSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -71,11 +74,13 @@ class TorchSingleDriver(TorchDriver): | |||||
fn = getattr(self.model, fn) | fn = getattr(self.model, fn) | ||||
if not callable(fn): | if not callable(fn): | ||||
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | ||||
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') | |||||
return fn, None | return fn, None | ||||
elif fn in {"train_step", "evaluate_step"}: | elif fn in {"train_step", "evaluate_step"}: | ||||
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') | |||||
return self.model, self.model.forward | return self.model, self.model.forward | ||||
else: | else: | ||||
raise RuntimeError(f"There is no `{fn}` method in your model.") | |||||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | ||||
reproducible: bool = False): | reproducible: bool = False): | ||||
@@ -96,12 +101,18 @@ class TorchSingleDriver(TorchDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
if isinstance(args.sampler, TorchRandomSampler): | |||||
# 如果本来就是随机的,直接替换掉吧。 | |||||
sampler = RandomSampler(args.sampler.data_source) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||
@@ -164,7 +164,7 @@ def _get_keys(args:List[Dict]) -> List[List[str]]: | |||||
return _provided_keys | return _provided_keys | ||||
def _get_fun_msg(fn)->str: | |||||
def _get_fun_msg(fn, with_fp=True)->str: | |||||
""" | """ | ||||
获取函数的基本信息,帮助报错。 | 获取函数的基本信息,帮助报错。 | ||||
ex: | ex: | ||||
@@ -172,6 +172,7 @@ def _get_fun_msg(fn)->str: | |||||
# `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) | # `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) | ||||
:param callable fn: | :param callable fn: | ||||
:param with_fp: 是否包含函数所在的文件信息。 | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(fn, functools.partial): | if isinstance(fn, functools.partial): | ||||
@@ -180,9 +181,12 @@ def _get_fun_msg(fn)->str: | |||||
fn_name = fn.__qualname__ + str(inspect.signature(fn)) | fn_name = fn.__qualname__ + str(inspect.signature(fn)) | ||||
except: | except: | ||||
fn_name = str(fn) | fn_name = str(fn) | ||||
try: | |||||
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' | |||||
except: | |||||
if with_fp: | |||||
try: | |||||
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' | |||||
except: | |||||
fp = '' | |||||
else: | |||||
fp = '' | fp = '' | ||||
msg = f'`{fn_name}`' + fp | msg = f'`{fn_name}`' + fp | ||||
return msg | return msg | ||||
@@ -37,7 +37,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -63,7 +63,7 @@ def model_and_optimizers(request): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
return trainer_params | return trainer_params | ||||
@@ -124,7 +124,7 @@ def test_model_checkpoint_callback_1( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -204,7 +204,7 @@ def test_model_checkpoint_callback_1( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -264,7 +264,7 @@ def test_model_checkpoint_callback_2( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -302,7 +302,7 @@ def test_model_checkpoint_callback_2( | |||||
device=4, | device=4, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -370,7 +370,7 @@ def test_trainer_checkpoint_callback_1( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -448,7 +448,7 @@ def test_trainer_checkpoint_callback_1( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -473,12 +473,12 @@ def test_trainer_checkpoint_callback_1( | |||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | ||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.skip("Skip transformers test for now.") | |||||
def test_trainer_checkpoint_callback_2( | def test_trainer_checkpoint_callback_2( | ||||
driver, | driver, | ||||
device, | device, | ||||
version | version | ||||
): | ): | ||||
pytest.skip("Skip transformers test for now.") | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | path = Path.cwd().joinpath(f"test_model_checkpoint") | ||||
path.mkdir(exist_ok=True, parents=True) | path.mkdir(exist_ok=True, parents=True) | ||||
@@ -40,7 +40,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -66,7 +66,7 @@ def model_and_optimizers(request): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
return trainer_params | return trainer_params | ||||
@@ -92,7 +92,7 @@ def test_load_best_model_callback( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -105,7 +105,7 @@ def test_load_best_model_callback( | |||||
driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda')) | driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda')) | ||||
evaluator = Evaluator(model_and_optimizers.model, driver=driver, device=device, | evaluator = Evaluator(model_and_optimizers.model, driver=driver, device=device, | ||||
dataloaders={'dl1': model_and_optimizers.validate_dataloaders}, | |||||
dataloaders={'dl1': model_and_optimizers.evaluate_dataloaders}, | |||||
metrics={'acc': Accuracy(aggregate_when_get_metric=False)}, | metrics={'acc': Accuracy(aggregate_when_get_metric=False)}, | ||||
output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | ||||
progress_bar='rich', use_dist_sampler=False) | progress_bar='rich', use_dist_sampler=False) | ||||
@@ -75,7 +75,7 @@ _dataloader = DataLoader( | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
train_dataloader = _dataloader | train_dataloader = _dataloader | ||||
validate_dataloaders = _dataloader | |||||
evaluate_dataloaders = _dataloader | |||||
metrics = {"acc": Accuracy()} | metrics = {"acc": Accuracy()} | ||||
@@ -89,7 +89,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
device=None, | device=None, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
evaluate_dataloaders=validate_dataloaders, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
metrics=metrics, | metrics=metrics, | ||||
n_epochs=2, | n_epochs=2, | ||||
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te | |||||
import argparse | import argparse | ||||
import os | import os | ||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" | |||||
import sys | import sys | ||||
path = os.path.abspath(__file__) | path = os.path.abspath(__file__) | ||||
@@ -63,7 +63,7 @@ _dataloader = DataLoader( | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
train_dataloader = _dataloader | train_dataloader = _dataloader | ||||
validate_dataloaders = _dataloader | |||||
evaluate_dataloaders = _dataloader | |||||
metrics = {"acc": Accuracy()} | metrics = {"acc": Accuracy()} | ||||
@@ -77,7 +77,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
device=None, | device=None, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
evaluate_dataloaders=validate_dataloaders, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
metrics=metrics, | metrics=metrics, | ||||
n_epochs=2, | n_epochs=2, | ||||
@@ -30,7 +30,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -57,7 +57,7 @@ def model_and_optimizers(): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
return trainer_params | return trainer_params | ||||
@@ -82,7 +82,7 @@ def test_trainer_event_trigger( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -43,7 +43,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -71,7 +71,7 @@ def model_and_optimizers(request): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
elif request.param == 1: | elif request.param == 1: | ||||
@@ -91,7 +91,7 @@ def model_and_optimizers(request): | |||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.train_dataloader = _dataloader | trainer_params.train_dataloader = _dataloader | ||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.evaluate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | trainer_params.metrics = {"acc": Accuracy()} | ||||
return trainer_params | return trainer_params | ||||
@@ -116,7 +116,7 @@ def test_trainer_torch_with_evaluator( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -152,7 +152,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -193,7 +193,7 @@ def test_trainer_validate_every( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -38,7 +38,7 @@ class TrainerParameters: | |||||
model: Any = None | model: Any = None | ||||
optimizers: Any = None | optimizers: Any = None | ||||
train_dataloader: Any = None | train_dataloader: Any = None | ||||
validate_dataloaders: Any = None | |||||
evaluate_dataloaders: Any = None | |||||
input_mapping: Any = None | input_mapping: Any = None | ||||
output_mapping: Any = None | output_mapping: Any = None | ||||
metrics: Any = None | metrics: Any = None | ||||
@@ -65,7 +65,7 @@ def model_and_optimizers(request): | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | batch_size=NormalClassificationTrainTorchConfig.batch_size, | ||||
shuffle=True | shuffle=True | ||||
) | ) | ||||
trainer_params.validate_dataloaders = None | |||||
trainer_params.evaluate_dataloaders = None | |||||
trainer_params.input_mapping = None | trainer_params.input_mapping = None | ||||
trainer_params.output_mapping = None | trainer_params.output_mapping = None | ||||
@@ -91,7 +91,7 @@ def test_trainer_torch_without_evaluator( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -126,7 +126,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -163,7 +163,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( | |||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -202,7 +202,7 @@ def test_trainer_output_from_new_proc( | |||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
@@ -267,7 +267,7 @@ def test_trainer_on_exception( | |||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||