|
@@ -32,9 +32,16 @@ Tester calls model.eval() before validation to signal that the model has entered evaluation
 
 """
 
+import time
+
 import torch
 import torch.nn as nn
 
+try:
+    from tqdm.auto import tqdm
+except:
+    from .utils import _pseudo_tqdm as tqdm
+
 from .batch import BatchIter, DataSetIter
 from .dataset import DataSet
 from .metrics import _prepare_metrics
@@ -47,7 +54,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
-from fastNLP.core._parallel_utils import _model_contains_inner_module
+from ._parallel_utils import _model_contains_inner_module
 from functools import partial
 
 __all__ = [
@@ -80,9 +87,10 @@ class Tester(object):
         If the model runs prediction through predict(), multi-card (DataParallel) evaluation is not supported; only the model on the first card will be used.
 
     :param int verbose: if 0, print nothing; if 1, print the evaluation results.
+    :param bool use_tqdm: whether to use tqdm to show test progress; if False, nothing is displayed.
     """
     
-    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
         super(Tester, self).__init__()
         
         if not isinstance(model, nn.Module):
@@ -94,6 +102,7 @@ class Tester(object):
         self._model = _move_model_to_device(model, device=device)
         self.batch_size = batch_size
         self.verbose = verbose
+        self.use_tqdm = use_tqdm
         
         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(
@@ -141,21 +150,39 @@
         eval_results = {}
         try:
             with torch.no_grad():
-                for batch_x, batch_y in data_iterator:
-                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
-                    pred_dict = self._data_forward(self._predict_func, batch_x)
-                    if not isinstance(pred_dict, dict):
-                        raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
-                                        f"must be `dict`, got {type(pred_dict)}.")
-                    for metric in self.metrics:
-                        metric(pred_dict, batch_y)
-                for metric in self.metrics:
-                    eval_result = metric.get_metric()
-                    if not isinstance(eval_result, dict):
-                        raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
-                                        f"`dict`, got {type(eval_result)}")
-                    metric_name = metric.__class__.__name__
-                    eval_results[metric_name] = eval_result
+                if not self.use_tqdm:
+                    from .utils import _pseudo_tqdm as inner_tqdm
+                else:
+                    inner_tqdm = tqdm
+                with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
+                    pbar.set_description_str(desc="Test")
+
+                    start_time = time.time()
+
+                    for batch_x, batch_y in data_iterator:
+                        _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                        pred_dict = self._data_forward(self._predict_func, batch_x)
+                        if not isinstance(pred_dict, dict):
+                            raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
+                                            f"must be `dict`, got {type(pred_dict)}.")
+                        for metric in self.metrics:
+                            metric(pred_dict, batch_y)
+
+                        if self.use_tqdm:
+                            pbar.update()
+
+                    for metric in self.metrics:
+                        eval_result = metric.get_metric()
+                        if not isinstance(eval_result, dict):
+                            raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
+                                            f"`dict`, got {type(eval_result)}")
+                        metric_name = metric.__class__.__name__
+                        eval_results[metric_name] = eval_result
+
+                    end_time = time.time()
+                    test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
+                    pbar.write(test_str)
+                    pbar.close()
         except _CheckError as e:
             prev_func_signature = _get_func_signature(self._predict_func)
             _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
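Usage note (not part of the patch): the change above threads a `use_tqdm` flag through fastNLP's Tester, substituting the package's `_pseudo_tqdm` stub when the progress bar is disabled or tqdm is not installed, and reports the wall-clock time of the evaluation loop. A minimal sketch of how the new flag would be used follows; `model` and `test_data` are hypothetical placeholders the caller must supply, not names from this patch:

    # Hypothetical example: `model` (an nn.Module) and `test_data` (a fastNLP
    # DataSet) are assumed to exist; only the `use_tqdm` flag is new here.
    from fastNLP import Tester, AccuracyMetric

    tester = Tester(data=test_data, model=model,
                    metrics=AccuracyMetric(),
                    batch_size=32,
                    use_tqdm=False)   # added by this patch: suppress the progress bar
    eval_results = tester.test()      # runs under torch.no_grad(); prints timing via pbar.write
    print(eval_results)               # a dict keyed by metric class name, e.g. 'AccuracyMetric'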