@@ -8,6 +8,8 @@ import torch
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import CheckError

class MetricBase(object):
    def __init__(self):
@@ -29,7 +31,7 @@ class MetricBase(object):
            if not isinstance(value, str):
                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
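
For orientation, here is a minimal sketch of how this key remapping is meant to be used. The subclass, the field names (`output`, `label`), and the `evaluate` arguments are all hypothetical, not part of this diff; only `param_map` itself comes from the code above:

```python
import torch

class ToyAccuracy(MetricBase):
    def __init__(self):
        super().__init__()
        # map evaluate()'s `pred` argument to the model's `output` key,
        # and `target` to the dataset's `label` key (names are made up)
        self.param_map = {"pred": "output", "target": "label"}

    def evaluate(self, pred, target):
        return {"acc": (pred.argmax(dim=-1) == target).float().mean().item()}
```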
    def __call__(self, output_dict, target_dict, force_check=False):
        """
        :param output_dict: dict produced by the model's forward (or predict), keyed by output name
@@ -67,7 +69,7 @@ class MetricBase(object):
            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
            self._reverse_param_map = {value: key for key, value in self.param_map.items()}
            for key, value in check_res.items():
                new_value = value.copy()
                new_value = list(value)
                for idx, func_param in enumerate(value):
                    if func_param in self._reverse_param_map:
                        new_value[idx] = self._reverse_param_map[func_param]
@@ -85,28 +87,12 @@ class MetricBase(object):
        return metrics
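
Assuming the toy metric sketched earlier, a call would look roughly like this: `__call__` renames the dict keys according to `param_map` and then forwards only the arguments that `evaluate` declares. Values and shapes are illustrative:

```python
import torch

metric = ToyAccuracy()
output_dict = {"output": torch.tensor([[0.9, 0.1], [0.2, 0.8]])}
target_dict = {"label": torch.tensor([0, 1])}
# 'output' -> 'pred', 'label' -> 'target', then evaluate(pred, target)
print(metric(output_dict, target_dict))  # e.g. {'acc': 1.0}
```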
class CheckError(Exception):
    def __init__(self, check_res):
        err = ''
        if check_res.missing:
            err += f'Missing: {check_res.missing}\n'
        if check_res.duplicated:
            err += f'Duplicated: {check_res.duplicated}\n'
        self.check_res = check_res

    def __str__(self):
        pass

class Metric(MetricBase):
    def __init__(self, func, key_map, **kwargs):
        super().__init__()
        pass
def _prepare_metrics(metrics):
    """
@@ -127,8 +113,8 @@ def _prepare_metrics(metrics):
    elif isinstance(metrics, MetricBase):
        _metrics = [metrics]
    else:
        raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
                        .format(type(metrics)))
        raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, "
                        f"got {type(metrics)}.")
    return _metrics
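
Reading the visible branches, `_prepare_metrics` normalizes its argument to a list (the list-handling branch is elided above). A quick sketch of the contract, reusing the toy metric:

```python
_prepare_metrics(ToyAccuracy())    # -> [<ToyAccuracy instance>]
_prepare_metrics([ToyAccuracy()])  # presumably validated element-wise (elided branch)
_prepare_metrics("accuracy")       # -> TypeError, with the message above
```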
@@ -5,12 +5,13 @@ import torch
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError

class Tester(object):
    """A collection of model inference and performance-evaluation routines, used on validation/dev sets and test sets."""
@@ -33,7 +34,7 @@ class Tester(object):
                raise TypeError(f"`{_model_name}.predict` must be callable to be used "
                                f"for evaluation, not `{type(self._predict_func)}`.")
        else:
            self._predict_func = self._model
            self._predict_func = self._model.forward
        self.data = data
        if torch.cuda.is_available() and self.use_cuda:
    def test(self):
        # turn on the testing mode; clean up the history
        network = self._model
        self.mode(network, is_test=True)
        self._mode(network, is_test=True)
        output, truths = defaultdict(list), defaultdict(list)
        data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)
        data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
        with torch.no_grad():
            for batch_x, batch_y in data_iterator:
                _move_dict_value_to_device(self._model_device, batch_x, batch_y)
                prediction = self.data_forward(network, batch_x)
                prediction = self._data_forward(self._predict_func, batch_x)
                assert isinstance(prediction, dict)
                for k, v in prediction.items():
                    output[k].append(v)
@@ -68,16 +69,21 @@ class Tester(object):
        for k, v in truths.items():
            truths[k] = list(itertools.chain(*v))
        eval_results = {}
        try:
            for metric in self.metrics:
                eval_result = metric(output, truths)
                metric_name = metric.__class__.__name__
                eval_results[metric_name] = eval_result
        except CheckError as e:
            # re-raise instead of swallowing: the attached CheckRes is what
            # tells the user which arguments were missing or unused
            raise
        if self.verbose >= 0:
            print("[tester] \n{}".format(self.format_eval_results(eval_results)))
        self.mode(network, is_test=False)
            print("[tester] \n{}".format(self._format_eval_results(eval_results)))
        self._mode(network, is_test=False)
        return eval_results
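
Taken together, the refactored tester is driven like this; a minimal sketch, assuming `dev_data` is a fastNLP `DataSet` and reusing the toy metric from above (all values illustrative):

```python
tester = Tester(data=dev_data,          # a fastNLP DataSet
                model=model,            # torch.nn.Module, optionally with .predict
                metrics=ToyAccuracy(),
                batch_size=16,
                verbose=0)              # verbose=-1 suppresses the "[tester]" printout
eval_results = tester.test()            # e.g. {'ToyAccuracy': {'acc': 0.92}}
```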
    def mode(self, model, is_test=False):
    def _mode(self, model, is_test=False):
        """Set the model to train mode or test mode. This is for PyTorch models currently.

        :param model: a PyTorch model
@@ -89,13 +95,13 @@ class Tester(object):
        else:
            model.train()
    def data_forward(self, network, x):
    def _data_forward(self, func, x):
        """A forward pass of the model."""
        x = _build_args(network.forward, **x)
        y = self._predict_func(**x)
        x = _build_args(func, **x)
        y = func(**x)
        return y
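
The `_build_args` call is what lets batches carry extra fields: it inspects the function's signature and keeps only the keyword arguments the function can accept. A self-contained sketch of that idea (toy names, not the fastNLP implementation):

```python
import inspect

def build_args_sketch(func, **kwargs):
    # keep only the kwargs that appear in func's signature
    params = inspect.signature(func).parameters
    return {k: v for k, v in kwargs.items() if k in params}

def forward(words, seq_len):
    return {"output": (words, seq_len)}

x = build_args_sketch(forward, words=1, seq_len=2, label=3)  # `label` is dropped
assert x == {"words": 1, "seq_len": 2}
```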
    def format_eval_results(self, results):
    def _format_eval_results(self, results):
        """Override this method to support more print formats.

        :param results: dict of (metric name: value) pairs, mapping str to float
@@ -25,7 +25,7 @@ from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError

class Trainer(object):
    """Main Training Loop
    def get_loss(self, predict, truth):
        """Compute loss given prediction and ground truth.

        :param predict: prediction label vector
        :param truth: ground truth label vector
        :param predict: prediction dict, produced by model.forward
        :param truth: ground truth dict, produced by batch_y
        :return: a scalar
        """
        assert isinstance(predict, dict) and isinstance(truth, dict)
        args = _build_args(self.loss_func, **predict, **truth)
        return self.loss_func(**args)
        return self.losser(predict, truth)
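
With this change the trainer simply delegates: `self.losser` is expected to be a `LossBase`-style callable that accepts the two dicts and resolves argument names itself, mirroring `MetricBase.__call__`. A hedged sketch of such an object (`LossBase` internals are not part of this diff; the key names are made up):

```python
import torch.nn.functional as F

class ToyCrossEntropy:
    """Toy stand-in for a LossBase subclass: callable on (predict, truth) dicts."""
    def __call__(self, predict, truth):
        # a real losser would remap keys via a param_map and raise CheckError
        # on mismatches; here the key names are simply assumed
        return F.cross_entropy(predict["output"], truth["label"])
```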
    def save_model(self, model, model_name, only_param=False):
        model_name = os.path.join(self.save_path, model_name)
@@ -260,11 +258,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
                dev_data=None,
                check_level=WARNING_CHECK_LEVEL):
    # check the get_loss method
    model_name = model.__class__.__name__
    model_device = model.parameters().__next__().device
    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        _syn_model_data(model, batch_x, batch_y)
        _move_dict_value_to_device(model_device, batch_x, batch_y)
        # forward check
        if batch_count == 0:
            _check_forward_error(model_func=model.forward, check_level=check_level,
@@ -277,68 +275,29 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
            raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")

        # loss check
        if isinstance(losser, type):  # in this case the user passed an uninitialized loss, e.g. losser.CE
            # need to make sure output and batch_y are unambiguous?
            # (1) output and batch_y each have exactly one key
            # (2) the keys of output and batch_y match exactly what the losser accepts
            pass
        loss = losser(output, batch_y)
        try:
            loss = losser(output, batch_y)
        except CheckError as e:
            _check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
                                 check_res=e.check_res, output=output, batch_y=batch_y,
                                 check_level=check_level)
        # check loss output
        if batch_count == 0:
            if not isinstance(loss, torch.Tensor):
                raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
                                 format(type(losser), type(loss)))
                raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
                                f"but got `{type(loss)}`.")
            if len(loss.size()) != 0:
                raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
                    type(losser), loss.size()
                ))
                raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
                                 f"should be `torch.Size([])`")
        loss.backward()
        model.zero_grad()
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break
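
The except branch above is the consumer side of the new contract: the losser works out which of its declared parameters are missing from the two dicts and which provided keys go unused, and ships those lists inside the `CheckRes` attached to the error. The computation amounts to this self-contained sketch (toy names throughout):

```python
import inspect

def toy_loss(pred, target):  # the losser's underlying function
    return 0.0

output = {"output": 1}   # produced by model.forward
batch_y = {"target": 0}  # produced by batch_y
provided = {**output, **batch_y}
params = inspect.signature(toy_loss).parameters
missing = [p for p in params if p not in provided]  # -> ['pred']
unused = [k for k in provided if k not in params]   # -> ['output']
```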
    if dev_data is not None:
        outputs, truths = defaultdict(list), defaultdict(list)
        dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
        # TODO: switch this to use the tester
        tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size)
        with torch.no_grad():
            for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
                _syn_model_data(model, batch_x, batch_y)
                if hasattr(model, 'predict'):
                    if not callable(model.predict):
                        raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
                                        f"for evaluation.")
                    refined_batch_x = _build_args(model.predict, **batch_x)
                    prev_func = model.predict
                    output = prev_func(**refined_batch_x)
                else:
                    refined_batch_x = _build_args(model.forward, **batch_x)
                    prev_func = model.forward
                    output = prev_func(**refined_batch_x)
                func_signature = get_func_signature(prev_func)
                if not isinstance(output, dict):
                    raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
                for k, v in output.items():
                    outputs[k].append(v)
                for k, v in batch_y.items():
                    truths[k].append(v)
                if batch_count + 1 > DEFAULT_CHECK_NUM_BATCH:
                    break
            for k, v in outputs.items():
                outputs[k] = tuple(itertools.chain(*v))
            for k, v in truths.items():
                truths[k] = tuple(itertools.chain(*v))
        # TODO: update this for the new metrics API; errors raised by a metric must
        # also be caught here, since they are needed to guide the user's debugging
        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                        batch_size=batch_size, verbose=-1)
        tester.test()
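
Two design choices are worth noting in the replacement: the slice `dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH]` keeps the sanity check cheap by evaluating only the first few batches, and `verbose=-1` keeps it silent. Assuming `DataSet` supports slicing (as this line implies), the dev-side check reduces to:

```python
# evaluate only the first few batches, quietly; a CheckError raised inside
# Tester.test() should propagate so the user gets the missing/unused hints
small = dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH]
Tester(data=small, model=model, metrics=metrics,
       batch_size=batch_size, verbose=-1).test()
```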
def _check_forward_error(model_func, check_level, batch_x):
@@ -346,11 +305,11 @@ def _check_forward_error(model_func, check_level, batch_x):
    _missing = ''
    _unused = ''
    func_signature = get_func_signature(model_func)
    if len(check_res.missing) != 0:
    if len(check_res['missing']) != 0:
        _missing = "Function {} misses {}, only provided with {}.\n".format(
            func_signature, check_res.missing, list(batch_x.keys()))
    if len(check_res.unused) != 0:
    if len(check_res['unused']) != 0:
        if len(check_res.unused) > 1:
            _unused = "{} are not used ".format(check_res.unused)
        else:
@@ -370,9 +329,7 @@ def _check_forward_error(model_func, check_level, batch_x):
    elif check_level == WARNING_CHECK_LEVEL:
        warnings.warn(message=_unused)
def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
    check_res = _check_arg_dict_list(func, [output, batch_y])
def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
    _missing = ''
    _unused = ''
    _duplicated = ''
@@ -220,13 +220,16 @@ class CheckError(Exception):
    CheckError. Used in losses.LossBase, metrics.MetricBase.
    """
    def __init__(self, check_res):
    def __init__(self, check_res: CheckRes, func_signature: str):
        err = ''
        if check_res['missing']:
            err += f"Missing: {check_res['missing']}\n"
        if check_res['duplicated']:
            err += f"Duplicated: {check_res['duplicated']}\n"
        if check_res['unused']:
            err += f"Unused: {check_res['unused']}\n"
        if check_res.missing:
            err += f"Missing: {check_res.missing}\n"
        if check_res.duplicated:
            err += f"Duplicated: {check_res.duplicated}\n"
        if check_res.unused:
            err += f"Unused: {check_res.unused}\n"
        Exception.__init__(self, err)
        self.check_res = check_res
        self.func_signature = func_signature
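
For completeness, the raising side; a minimal sketch assuming `CheckRes` is a namedtuple with at least `missing`, `unused`, and `duplicated` fields (its definition is outside this diff), and with illustrative values:

```python
from collections import namedtuple

CheckRes = namedtuple("CheckRes", ["missing", "unused", "duplicated"])

res = CheckRes(missing=["target"], unused=["seq_len"], duplicated=[])
try:
    raise CheckError(check_res=res, func_signature="evaluate(pred, target)")
except CheckError as e:
    print(e)                 # Missing: ['target']  /  Unused: ['seq_len']
    print(e.func_signature)  # evaluate(pred, target)
```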