@@ -49,8 +49,8 @@ class Evaluator:
     ):
         """
-        :param model:
-        :param dataloaders:
+        :param model: the model to be tested; if the `driver` passed in is a Driver instance, this parameter will be ignored.
+        :param dataloaders: the dataloaders to evaluate on.
         :param metrics: the metrics to use. Must be a dict whose keys are metric names and whose values are Metric objects. fastNLP
             metrics, torchmetrics, allennlp metrics, etc. are supported.
         :param driver: the driver to use.
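
For orientation, a minimal sketch of how these parameters fit together (the import paths and the `.run()` call are assumed from fastNLP's public API, not guaranteed by this hunk; `model` and `dev_loader` are placeholders):

```python
from fastNLP import Evaluator
from fastNLP.core.metrics import Accuracy  # import path assumed

evaluator = Evaluator(
    model=model,                      # ignored if `driver` is a Driver instance
    dataloaders={'dev': dev_loader},  # dataloaders to evaluate on
    metrics={'acc': Accuracy()},      # dict: metric name -> Metric object
    driver='torch',                   # driver to use
)
results = evaluator.run()
```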
@@ -119,7 +119,7 @@ class Evaluator:
         self.driver.barrier()
         if evaluate_fn is not None and not isinstance(evaluate_fn, str):
-            raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
+            raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.")
         self._evaluate_step, self._evaluate_step_signature_fn = \
             self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
         self.evaluate_fn = evaluate_fn
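
With the corrected message, the rejection now names the parameter the caller actually passed. A small sketch of the behavior (`model`, `dl`, and `Accuracy` here are illustrative placeholders):

```python
# Sketch: `evaluate_fn` must be None or a string naming a method on the model.
try:
    Evaluator(model, dataloaders=dl, metrics={'acc': Accuracy()},
              evaluate_fn=model.evaluate_step)  # a callable, not a str
except TypeError as e:
    print(e)  # Parameter `evaluate_fn` can only be `str` type when it is not None.
```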
@@ -86,10 +86,12 @@ class Trainer(TrainerEventTrigger):
             `batch`; defaults to None;
         :param evaluate_batch_step_fn: replaces the `batch_step_fn` of the `EvaluateBatchLoop` inside the `Evaluator`; note that
             its two parameters must be `evaluator` and `batch`; defaults to None;
-        :param train_fn: controls which function the `Trainer` calls for the forward pass during training, e.g. `model.train_step` or `model.forward`;
-            defaults to None; if None, `train_step` is used as the forward function by default, and if the model has no such method, `model.forward` is used;
+        :param train_fn: controls which of the model's functions the `Trainer` calls for the forward pass during training, e.g. `train_step` or `forward`;
+            defaults to None; if None, `train_step` is used as the forward function by default, and if the model has no such method,
+            the model's default forward function is used.
         :param evaluate_fn: controls the mode of the `Evaluator` built into the `Trainer`; should be None or a string; used the same way as `train_fn`;
-            note that this parameter is passed directly to the `Evaluator` built into the `Trainer` (if it is not None);
+            note that this parameter is passed directly to the `Evaluator` built into the `Trainer` (if it is not None); if it is None, we first check
+            whether the model has an `evaluate_step` method, and fall back to `forward` if it does not.
         :param callbacks: callbacks triggered during training; should be a list, each element of which inherits from `Callback`;
         :param metrics: should be a dict whose keys are monitor names, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
         :param evaluate_every: can be a negative number, a positive number, or a function; a negative value means validate once every that many epochs; a positive value means validate once every that many batches;
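
The lookup order that `train_fn` and `evaluate_fn` describe assumes a model shaped roughly like the following sketch (the dict-return convention is taken from the test code later in this diff; the class itself is hypothetical):

```python
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        # fallback used when the step methods below are absent
        return {'pred': self.linear(x)}

    def train_step(self, x, y):
        # picked up by Trainer when train_fn is None
        return {'loss': nn.functional.cross_entropy(self.linear(x), y)}

    def evaluate_step(self, x):
        # picked up by the built-in Evaluator when evaluate_fn is None
        return {'pred': self.linear(x)}
```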
@@ -5,6 +5,7 @@ if _NEED_IMPORT_TORCH:
     import torch
     from torch.nn import DataParallel
     from torch.nn.parallel import DistributedDataParallel
+    from torch.utils.data import RandomSampler as TorchRandomSampler

 __all__ = [
     'TorchSingleDriver'
@@ -13,7 +14,9 @@ __all__ = [
 from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
 from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
+from fastNLP.core.samplers import RandomSampler
 from fastNLP.core.log import logger
@@ -71,11 +74,13 @@ class TorchSingleDriver(TorchDriver):
             fn = getattr(self.model, fn)
             if not callable(fn):
                 raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
             return fn, None
         elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
             return self.model, self.model.forward
         else:
-            raise RuntimeError(f"There is no `{fn}` method in your model.")
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

     def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
                                   reproducible: bool = False):
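
Pulled out of its class, the resolution logic in this hunk amounts to the following standalone sketch:

```python
def resolve_call_fn(model, fn: str):
    # Prefer the named step method; fall back to forward for the two
    # well-known step names; otherwise fail loudly with the model type.
    if hasattr(model, fn):
        step = getattr(model, fn)
        if not callable(step):
            raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
        return step, None
    elif fn in {"train_step", "evaluate_step"}:
        return model, model.forward
    else:
        raise RuntimeError(f"There is no `{fn}` method in your {type(model)}.")
```

Returning the module itself in the fallback branch (rather than the bound `forward`) appears intended to route calls through `__call__` so hooks still fire; the hunk above keeps that behavior and only adds debug logging and the more informative error.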
@@ -96,12 +101,18 @@ class TorchSingleDriver(TorchDriver):
             return replace_sampler(dataloader, sampler)

         if reproducible:
-            batch_sampler = RandomBatchSampler(
-                batch_sampler=args.batch_sampler,
-                batch_size=args.batch_size,
-                drop_last=args.drop_last
-            )
-            return replace_batch_sampler(dataloader, batch_sampler)
+            if isinstance(args.sampler, TorchRandomSampler):
+                # If the sampler is already random, just replace it directly.
+                sampler = RandomSampler(args.sampler.data_source)
+                logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+                return replace_sampler(dataloader, sampler)
+            else:
+                batch_sampler = RandomBatchSampler(
+                    batch_sampler=args.batch_sampler,
+                    batch_size=args.batch_size,
+                    drop_last=args.drop_last
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
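
The effect of the new branch, standalone (a sketch; the `RandomSampler(dataset)` call shape is taken from the hunk itself):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import RandomSampler as TorchRandomSampler
from fastNLP.core.samplers import RandomSampler  # reproducible counterpart

dataset = TensorDataset(torch.randn(32, 4))
loader = DataLoader(dataset, batch_size=8, shuffle=True)  # torch RandomSampler inside

if isinstance(loader.sampler, TorchRandomSampler):
    # Shuffling stays, but fastNLP's sampler can save and restore its
    # state, so the run becomes reproducible without wrapping the
    # batch_sampler in a RandomBatchSampler.
    sampler = RandomSampler(loader.sampler.data_source)
```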
@@ -164,7 +164,7 @@ def _get_keys(args:List[Dict]) -> List[List[str]]:
     return _provided_keys

-def _get_fun_msg(fn)->str:
+def _get_fun_msg(fn, with_fp=True)->str:
     """
     Get basic information about a function, to help with error messages.
     ex:
@@ -172,6 +172,7 @@ def _get_fun_msg(fn)->str:
         # `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py)
     :param callable fn:
+    :param with_fp: whether to include the file in which the function is defined.
     :return:
     """
     if isinstance(fn, functools.partial):
@@ -180,9 +181,12 @@ def _get_fun_msg(fn)->str:
         fn_name = fn.__qualname__ + str(inspect.signature(fn))
     except:
         fn_name = str(fn)
-    try:
-        fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')'
-    except:
+    if with_fp:
+        try:
+            fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')'
+        except:
+            fp = ''
+    else:
         fp = ''
     msg = f'`{fn_name}`' + fp
     return msg
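
Expected output shapes after the change, following the docstring example above (sketch):

```python
def fn(x): ...

_get_fun_msg(fn)                 # `fn(x)`(In file:/abs/path/to/module.py)
_get_fun_msg(fn, with_fp=False)  # `fn(x)`  <- terse form used by the new debug logs
```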
@@ -37,7 +37,7 @@ class TrainerParameters:
     model: Any = None
     optimizers: Any = None
     train_dataloader: Any = None
-    validate_dataloaders: Any = None
+    evaluate_dataloaders: Any = None
     input_mapping: Any = None
     output_mapping: Any = None
     metrics: Any = None
@@ -63,7 +63,7 @@ def model_and_optimizers(request):
         shuffle=True
     )
     trainer_params.train_dataloader = _dataloader
-    trainer_params.validate_dataloaders = _dataloader
+    trainer_params.evaluate_dataloaders = _dataloader
     trainer_params.metrics = {"acc": Accuracy()}

     return trainer_params
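
The hunks from here on are the mechanical fallout of renaming `validate_dataloaders` to `evaluate_dataloaders`: the fixture field and every `Trainer(...)` keyword move together, roughly as below (sketch; other required Trainer arguments omitted):

```python
trainer = Trainer(
    model=trainer_params.model,
    driver='torch',
    train_dataloader=trainer_params.train_dataloader,
    evaluate_dataloaders=trainer_params.evaluate_dataloaders,  # was validate_dataloaders
    metrics=trainer_params.metrics,
)
```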
@@ -124,7 +124,7 @@ def test_model_checkpoint_callback_1(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -204,7 +204,7 @@ def test_model_checkpoint_callback_1(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -264,7 +264,7 @@ def test_model_checkpoint_callback_2(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -302,7 +302,7 @@ def test_model_checkpoint_callback_2(
         device=4,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -370,7 +370,7 @@ def test_trainer_checkpoint_callback_1(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -448,7 +448,7 @@ def test_trainer_checkpoint_callback_1(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -473,12 +473,12 @@ def test_trainer_checkpoint_callback_1(
 @pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
+@pytest.mark.skip("Skip transformers test for now.")
 def test_trainer_checkpoint_callback_2(
     driver,
     device,
     version
 ):
-    pytest.skip("Skip transformers test for now.")
     path = Path.cwd().joinpath(f"test_model_checkpoint")
     path.mkdir(exist_ok=True, parents=True)
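
The skip also changes form: a `pytest.skip(...)` call inside the body only fires once the test starts executing, while the `@pytest.mark.skip` decorator marks the test as skipped up front, before fixtures or parametrized setup run:

```python
import pytest

@pytest.mark.skip("Skip transformers test for now.")
def test_skipped_up_front():
    ...  # body never entered; no fixture setup

def test_skipped_at_runtime():
    pytest.skip("Skip transformers test for now.")  # setup already ran
```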
@@ -40,7 +40,7 @@ class TrainerParameters:
     model: Any = None
     optimizers: Any = None
     train_dataloader: Any = None
-    validate_dataloaders: Any = None
+    evaluate_dataloaders: Any = None
     input_mapping: Any = None
     output_mapping: Any = None
     metrics: Any = None
@@ -66,7 +66,7 @@ def model_and_optimizers(request):
         shuffle=True
     )
     trainer_params.train_dataloader = _dataloader
-    trainer_params.validate_dataloaders = _dataloader
+    trainer_params.evaluate_dataloaders = _dataloader
     trainer_params.metrics = {"acc": Accuracy()}

     return trainer_params
@@ -92,7 +92,7 @@ def test_load_best_model_callback(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=lambda output: output if ('loss' in output) else {'pred': output['preds'], 'target': output['target']},
         metrics=model_and_optimizers.metrics,
@@ -105,7 +105,7 @@ def test_load_best_model_callback(
     driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda'))
     evaluator = Evaluator(model_and_optimizers.model, driver=driver, device=device,
-                          dataloaders={'dl1': model_and_optimizers.validate_dataloaders},
+                          dataloaders={'dl1': model_and_optimizers.evaluate_dataloaders},
                           metrics={'acc': Accuracy(aggregate_when_get_metric=False)},
                           output_mapping=lambda output: output if ('loss' in output) else {'pred': output['preds'], 'target': output['target']},
                           progress_bar='rich', use_dist_sampler=False)
@@ -75,7 +75,7 @@ _dataloader = DataLoader(
     shuffle=True
 )
 train_dataloader = _dataloader
-validate_dataloaders = _dataloader
+evaluate_dataloaders = _dataloader
 metrics = {"acc": Accuracy()}
@@ -89,7 +89,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         device=None,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
-        evaluate_dataloaders=validate_dataloaders,
+        evaluate_dataloaders=evaluate_dataloaders,
         metrics=metrics,
         n_epochs=2,
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te
 import argparse
 import os

 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
 import sys
 path = os.path.abspath(__file__)
@@ -63,7 +63,7 @@ _dataloader = DataLoader(
     shuffle=True
 )
 train_dataloader = _dataloader
-validate_dataloaders = _dataloader
+evaluate_dataloaders = _dataloader
 metrics = {"acc": Accuracy()}
@@ -77,7 +77,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         device=None,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
-        evaluate_dataloaders=validate_dataloaders,
+        evaluate_dataloaders=evaluate_dataloaders,
         metrics=metrics,
         n_epochs=2,
@@ -30,7 +30,7 @@ class TrainerParameters:
     model: Any = None
     optimizers: Any = None
     train_dataloader: Any = None
-    validate_dataloaders: Any = None
+    evaluate_dataloaders: Any = None
     input_mapping: Any = None
     output_mapping: Any = None
     metrics: Any = None
@@ -57,7 +57,7 @@ def model_and_optimizers():
         shuffle=True
     )
     trainer_params.train_dataloader = _dataloader
-    trainer_params.validate_dataloaders = _dataloader
+    trainer_params.evaluate_dataloaders = _dataloader
     trainer_params.metrics = {"acc": Accuracy()}

     return trainer_params
@@ -82,7 +82,7 @@ def test_trainer_event_trigger(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -43,7 +43,7 @@ class TrainerParameters:
     model: Any = None
     optimizers: Any = None
     train_dataloader: Any = None
-    validate_dataloaders: Any = None
+    evaluate_dataloaders: Any = None
     input_mapping: Any = None
     output_mapping: Any = None
     metrics: Any = None
@@ -71,7 +71,7 @@ def model_and_optimizers(request):
             shuffle=True
         )
         trainer_params.train_dataloader = _dataloader
-        trainer_params.validate_dataloaders = _dataloader
+        trainer_params.evaluate_dataloaders = _dataloader
         trainer_params.metrics = {"acc": Accuracy()}

     elif request.param == 1:
@@ -91,7 +91,7 @@ def model_and_optimizers(request):
             shuffle=True
         )
         trainer_params.train_dataloader = _dataloader
-        trainer_params.validate_dataloaders = _dataloader
+        trainer_params.evaluate_dataloaders = _dataloader
         trainer_params.metrics = {"acc": Accuracy()}

     return trainer_params
@@ -116,7 +116,7 @@ def test_trainer_torch_with_evaluator(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -152,7 +152,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -193,7 +193,7 @@ def test_trainer_validate_every(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -38,7 +38,7 @@ class TrainerParameters:
     model: Any = None
     optimizers: Any = None
     train_dataloader: Any = None
-    validate_dataloaders: Any = None
+    evaluate_dataloaders: Any = None
     input_mapping: Any = None
     output_mapping: Any = None
     metrics: Any = None
@@ -65,7 +65,7 @@ def model_and_optimizers(request):
         batch_size=NormalClassificationTrainTorchConfig.batch_size,
         shuffle=True
     )
-    trainer_params.validate_dataloaders = None
+    trainer_params.evaluate_dataloaders = None
     trainer_params.input_mapping = None
     trainer_params.output_mapping = None
@@ -91,7 +91,7 @@ def test_trainer_torch_without_evaluator(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -126,7 +126,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -163,7 +163,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps(
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -202,7 +202,7 @@ def test_trainer_output_from_new_proc(
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
@@ -267,7 +267,7 @@ def test_trainer_on_exception(
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.validate_dataloaders,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,