diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 067ff30c..691bf2ae 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -32,9 +32,16 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation
 
 
 """
+import time
+
 import torch
 import torch.nn as nn
 
+try:
+    from tqdm.auto import tqdm
+except:
+    from .utils import _pseudo_tqdm as tqdm
+
 from .batch import BatchIter, DataSetIter
 from .dataset import DataSet
 from .metrics import _prepare_metrics
@@ -47,7 +54,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
-from fastNLP.core._parallel_utils import _model_contains_inner_module
+from ._parallel_utils import _model_contains_inner_module
 from functools import partial
 
 __all__ = [
@@ -80,9 +87,10 @@ class Tester(object):
 
         如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
     :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
+    :param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。
     """
     
-    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
         super(Tester, self).__init__()
         
         if not isinstance(model, nn.Module):
@@ -94,6 +102,7 @@ class Tester(object):
         self._model = _move_model_to_device(model, device=device)
         self.batch_size = batch_size
         self.verbose = verbose
+        self.use_tqdm = use_tqdm
         
         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(
@@ -141,21 +150,39 @@ class Tester(object):
         eval_results = {}
         try:
             with torch.no_grad():
-                for batch_x, batch_y in data_iterator:
-                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
-                    pred_dict = self._data_forward(self._predict_func, batch_x)
-                    if not isinstance(pred_dict, dict):
-                        raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
-                                        f"must be `dict`, got {type(pred_dict)}.")
+                if not self.use_tqdm:
+                    from .utils import _pseudo_tqdm as inner_tqdm
+                else:
+                    inner_tqdm = tqdm
+                with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
+                    pbar.set_description_str(desc="Test")
+
+                    start_time = time.time()
+
+                    for batch_x, batch_y in data_iterator:
+                        _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                        pred_dict = self._data_forward(self._predict_func, batch_x)
+                        if not isinstance(pred_dict, dict):
+                            raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
+                                            f"must be `dict`, got {type(pred_dict)}.")
+                        for metric in self.metrics:
+                            metric(pred_dict, batch_y)
+
+                        if self.use_tqdm:
+                            pbar.update()
+
                     for metric in self.metrics:
-                        metric(pred_dict, batch_y)
-                for metric in self.metrics:
-                    eval_result = metric.get_metric()
-                    if not isinstance(eval_result, dict):
-                        raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
-                                        f"`dict`, got {type(eval_result)}")
-                    metric_name = metric.__class__.__name__
-                    eval_results[metric_name] = eval_result
+                        eval_result = metric.get_metric()
+                        if not isinstance(eval_result, dict):
+                            raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
+                                            f"`dict`, got {type(eval_result)}")
+                        metric_name = metric.__class__.__name__
+                        eval_results[metric_name] = eval_result
+
+                    end_time = time.time()
+                    test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
+                    pbar.write(test_str)
+                pbar.close()
         except _CheckError as e:
             prev_func_signature = _get_func_signature(self._predict_func)
             _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 83bdb4b0..a85b7fee 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -352,7 +352,7 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
-from fastNLP.core._parallel_utils import _model_contains_inner_module
+from ._parallel_utils import _model_contains_inner_module
 
 
 class Trainer(object):
@@ -557,7 +557,8 @@ class Trainer(object):
                                  metrics=self.metrics,
                                  batch_size=self.batch_size,
                                  device=None,  # 由上面的部分处理device
-                                 verbose=0)
+                                 verbose=0,
+                                 use_tqdm=self.use_tqdm)
         
         self.step = 0
         self.start_time = None  # start timestamp
@@ -633,7 +634,7 @@ class Trainer(object):
     
     def _train(self):
         if not self.use_tqdm:
-            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
+            from .utils import _pseudo_tqdm as inner_tqdm
         else:
             inner_tqdm = tqdm
         self.step = 0
@@ -859,8 +860,11 @@ def _get_value_info(_dict):
             strs.append(_str)
     return strs
 
+
 from numbers import Number
 from .batch import _to_tensor
+
+
 def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None, check_level=0):
     # check get_loss 方法
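
A minimal usage sketch of the new use_tqdm flag (illustrative only, not part of the patch): dev_data is assumed to be a prepared fastNLP DataSet and model a trained torch.nn.Module whose forward returns a dict; AccuracyMetric and Tester.test() are the existing fastNLP entry points, and only the use_tqdm argument is what this diff adds.

from fastNLP import Tester, AccuracyMetric

# dev_data (a DataSet with input/target fields set) and model (an nn.Module)
# are assumed to already exist; they are placeholders for illustration.
tester = Tester(data=dev_data,
                model=model,
                metrics=AccuracyMetric(pred='pred', target='target'),
                batch_size=32,
                use_tqdm=False)  # new argument: set False to suppress the tqdm progress bar
eval_results = tester.test()     # unchanged entry point; returns {metric_class_name: result_dict}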