diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 067ff30c..691bf2ae 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -32,9 +32,16 @@ Tester calls model.eval() before validation to indicate that evaluation has started
 """
+import time
+
 import torch
 import torch.nn as nn
 
+try:
+    from tqdm.auto import tqdm
+except ImportError:
+    from .utils import _pseudo_tqdm as tqdm
+
 from .batch import BatchIter, DataSetIter
 from .dataset import DataSet
 from .metrics import _prepare_metrics
@@ -47,7 +54,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
-from fastNLP.core._parallel_utils import _model_contains_inner_module
+from ._parallel_utils import _model_contains_inner_module
 from functools import partial
 
 __all__ = [
@@ -80,9 +87,10 @@ class Tester(object):
         If the model makes predictions via predict(), multi-GPU validation (DataParallel) cannot be used; only the model on the first card will be used.
     :param int verbose: if 0, print nothing; if 1, print the evaluation results.
+    :param bool use_tqdm: whether to use tqdm to display test progress; if False, nothing is shown.
     """
 
-    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
         super(Tester, self).__init__()
 
         if not isinstance(model, nn.Module):
@@ -94,6 +102,7 @@ class Tester(object):
         self._model = _move_model_to_device(model, device=device)
         self.batch_size = batch_size
         self.verbose = verbose
+        self.use_tqdm = use_tqdm
 
         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(
@@ -141,21 +150,39 @@ class Tester(object):
         eval_results = {}
         try:
             with torch.no_grad():
-                for batch_x, batch_y in data_iterator:
-                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
-                    pred_dict = self._data_forward(self._predict_func, batch_x)
-                    if not isinstance(pred_dict, dict):
-                        raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
-                                        f"must be `dict`, got {type(pred_dict)}.")
+                if not self.use_tqdm:
+                    from .utils import _pseudo_tqdm as inner_tqdm
+                else:
+                    inner_tqdm = tqdm
+                with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
+                    pbar.set_description_str(desc="Test")
+
+                    start_time = time.time()
+
+                    for batch_x, batch_y in data_iterator:
+                        _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                        pred_dict = self._data_forward(self._predict_func, batch_x)
+                        if not isinstance(pred_dict, dict):
+                            raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
+                                            f"must be `dict`, got {type(pred_dict)}.")
+                        for metric in self.metrics:
+                            metric(pred_dict, batch_y)
+
+                        if self.use_tqdm:
+                            pbar.update()
+
                     for metric in self.metrics:
-                        metric(pred_dict, batch_y)
-                for metric in self.metrics:
-                    eval_result = metric.get_metric()
-                    if not isinstance(eval_result, dict):
-                        raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
-                                        f"`dict`, got {type(eval_result)}")
-                    metric_name = metric.__class__.__name__
-                    eval_results[metric_name] = eval_result
+                        eval_result = metric.get_metric()
+                        if not isinstance(eval_result, dict):
+                            raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
+                                            f"`dict`, got {type(eval_result)}")
+                        metric_name = metric.__class__.__name__
+                        eval_results[metric_name] = eval_result
+
+                    end_time = time.time()
+                    test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
+                    pbar.write(test_str)
+                    pbar.close()
         except _CheckError as e:
             prev_func_signature = _get_func_signature(self._predict_func)
             _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 83bdb4b0..a85b7fee 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -352,7 +352,7 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
-from fastNLP.core._parallel_utils import _model_contains_inner_module
+from ._parallel_utils import _model_contains_inner_module
 
 
 class Trainer(object):
@@ -557,7 +557,8 @@ class Trainer(object):
                                  metrics=self.metrics,
                                  batch_size=self.batch_size,
                                  device=None,  # device is handled by the code above
-                                 verbose=0)
+                                 verbose=0,
+                                 use_tqdm=self.use_tqdm)
 
         self.step = 0
         self.start_time = None  # start timestamp
@@ -633,7 +634,7 @@ class Trainer(object):
 
     def _train(self):
         if not self.use_tqdm:
-            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
+            from .utils import _pseudo_tqdm as inner_tqdm
         else:
             inner_tqdm = tqdm
         self.step = 0
@@ -859,8 +860,11 @@ def _get_value_info(_dict):
             strs.append(_str)
     return strs
 
+
 from numbers import Number
 from .batch import _to_tensor
+
+
 def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None,
                 check_level=0):
     # check the get_loss method
diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md
index 8bdfb9fe..96ea7a10 100644
--- a/reproduction/text_classification/README.md
+++ b/reproduction/text_classification/README.md
@@ -11,6 +11,13 @@ LSTM+self_attention: paper link [A Structured Self-attentive Sentence Embedding]
 
 AWD-LSTM: paper link [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/pdf/1708.02182.pdf)
 
+# Dataset sources
+IMDB: http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+SST-2: https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
+SST: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
+yelp_full: https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M
+yelp_polarity: https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M
+
 # Datasets and summary of reproduced results
 
 Results reproduced with fastNLP vs. results reported in the papers (the value before the / is the fastNLP result, the value after it is the paper's; - means the paper did not report a result on that dataset)
diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py
index 0b8fc535..3482de70 100644
--- a/reproduction/text_classification/train_char_cnn.py
+++ b/reproduction/text_classification/train_char_cnn.py
@@ -203,7 +203,7 @@ callbacks.append(
 def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):
     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size,
                       metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1,
-                      n_epochs=num_epochs)
+                      n_epochs=num_epochs, callbacks=callbacks)
     print(trainer.train())
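
For reference, a minimal usage sketch of the use_tqdm flag this diff adds to Tester. The names dataset and model below are placeholders for objects prepared elsewhere (a fastNLP DataSet with a 'target' field, and an nn.Module whose forward()/predict() returns a 'pred' key); they are not part of this change.

    from fastNLP import Tester, AccuracyMetric

    # Hypothetical setup: `dataset` and `model` are assumed to exist;
    # neither is defined in this diff.
    tester = Tester(
        data=dataset,
        model=model,
        metrics=AccuracyMetric(pred='pred', target='target'),
        batch_size=32,
        use_tqdm=False,  # new flag: False falls back to _pseudo_tqdm, so no progress bar is drawn
    )
    eval_results = tester.test()  # returns {metric_class_name: {metric_name: value}}

Note that pbar.write(test_str) runs unconditionally in the new test loop, so the "Evaluate data in ... seconds!" timing summary is printed whether or not tqdm is enabled.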