@@ -8,6 +8,8 @@ import torch | |||||
from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
from fastNLP.core.utils import _check_arg_dict_list | from fastNLP.core.utils import _check_arg_dict_list | ||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
from fastNLP.core.utils import CheckError | |||||
class MetricBase(object): | class MetricBase(object): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -29,7 +31,7 @@ class MetricBase(object): | |||||
if isinstance(value, str): | if isinstance(value, str): | ||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | ||||
self.param_map[key] = value | self.param_map[key] = value | ||||
def __call__(self, output_dict, target_dict, force_check=False): | def __call__(self, output_dict, target_dict, force_check=False): | ||||
""" | """ | ||||
:param output_dict: | :param output_dict: | ||||
@@ -67,7 +69,7 @@ class MetricBase(object): | |||||
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict]) | check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict]) | ||||
self._reverse_param_map = {value:key for key, value in check_res.items()} | self._reverse_param_map = {value:key for key, value in check_res.items()} | ||||
for key, value in check_res.items(): | for key, value in check_res.items(): | ||||
new_value = value.copy() | |||||
new_value = list(value) | |||||
for idx, func_param in enumerate(value): | for idx, func_param in enumerate(value): | ||||
if func_param in self._reverse_param_map: | if func_param in self._reverse_param_map: | ||||
new_value[idx] = self._reverse_param_map[func_param] | new_value[idx] = self._reverse_param_map[func_param] | ||||
@@ -85,28 +87,12 @@ class MetricBase(object): | |||||
return metrics | return metrics | ||||
class CheckError(Exception): | |||||
def __init__(self, check_res): | |||||
err = '' | |||||
if check_res.missing: | |||||
err += f'Missing: {check_res.missing}\n' | |||||
if check_res.duplicated: | |||||
err += f'Duplicated: {check_res.duplicated}\n' | |||||
self.check_res = check_res | |||||
def __str__(self): | |||||
pass | |||||
class Metric(MetricBase): | class Metric(MetricBase): | ||||
def __init__(self, func, key_map, **kwargs): | def __init__(self, func, key_map, **kwargs): | ||||
super().__init__() | super().__init__() | ||||
pass | pass | ||||
def _prepare_metrics(metrics): | def _prepare_metrics(metrics): | ||||
""" | """ | ||||
@@ -127,8 +113,8 @@ def _prepare_metrics(metrics): | |||||
elif isinstance(metrics, MetricBase): | elif isinstance(metrics, MetricBase): | ||||
_metrics = [metrics] | _metrics = [metrics] | ||||
else: | else: | ||||
raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}." | |||||
.format(type(metrics))) | |||||
raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, " | |||||
f"got {type(metrics)}.") | |||||
return _metrics | return _metrics | ||||
@@ -5,12 +5,13 @@ import torch | |||||
from torch import nn | from torch import nn | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.sampler import RandomSampler | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
from fastNLP.core.utils import _move_dict_value_to_device | from fastNLP.core.utils import _move_dict_value_to_device | ||||
from fastNLP.core.metrics import _prepare_metrics | from fastNLP.core.metrics import _prepare_metrics | ||||
from fastNLP.core.utils import CheckError | |||||
class Tester(object): | class Tester(object): | ||||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | ||||
@@ -33,7 +34,7 @@ class Tester(object): | |||||
raise TypeError(f"`{_model_name}.predict` must be callable to be used " | raise TypeError(f"`{_model_name}.predict` must be callable to be used " | ||||
f"for evaluation, not `{type(self._predict_func)}`.") | f"for evaluation, not `{type(self._predict_func)}`.") | ||||
else: | else: | ||||
self._predict_func = self._model | |||||
self._predict_func = self._model.forward | |||||
self.data = data | self.data = data | ||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
@@ -50,14 +51,14 @@ class Tester(object): | |||||
def test(self): | def test(self): | ||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
network = self._model | network = self._model | ||||
self.mode(network, is_test=True) | |||||
self._mode(network, is_test=True) | |||||
output, truths = defaultdict(list), defaultdict(list) | output, truths = defaultdict(list), defaultdict(list) | ||||
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||||
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
_move_dict_value_to_device(self._model_device, batch_x, batch_y) | _move_dict_value_to_device(self._model_device, batch_x, batch_y) | ||||
prediction = self.data_forward(network, batch_x) | |||||
prediction = self._data_forward(self._predict_func, batch_x) | |||||
assert isinstance(prediction, dict) | assert isinstance(prediction, dict) | ||||
for k, v in prediction.items(): | for k, v in prediction.items(): | ||||
output[k].append(v) | output[k].append(v) | ||||
@@ -68,16 +69,21 @@ class Tester(object): | |||||
for k, v in truths.items(): | for k, v in truths.items(): | ||||
truths[k] = itertools.chain(*v) | truths[k] = itertools.chain(*v) | ||||
eval_results = {} | eval_results = {} | ||||
try: | |||||
for metric in self.metrics: | for metric in self.metrics: | ||||
eval_result = metric(output, truths) | eval_result = metric(output, truths) | ||||
metric_name = metric.__class__.__name__ | metric_name = metric.__class__.__name__ | ||||
eval_results[metric_name] = eval_result | eval_results[metric_name] = eval_result | ||||
except CheckError as e: | |||||
pass | |||||
if self.verbose >= 0: | if self.verbose >= 0: | ||||
print("[tester] \n{}".format(self.format_eval_results(eval_results))) | |||||
self.mode(network, is_test=False) | |||||
print("[tester] \n{}".format(self._format_eval_results(eval_results))) | |||||
self._mode(network, is_test=False) | |||||
return eval_results | return eval_results | ||||
def mode(self, model, is_test=False): | |||||
def _mode(self, model, is_test=False): | |||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
:param model: a PyTorch model | :param model: a PyTorch model | ||||
@@ -89,13 +95,13 @@ class Tester(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def data_forward(self, network, x): | |||||
def _data_forward(self, func, x): | |||||
"""A forward pass of the model. """ | """A forward pass of the model. """ | ||||
x = _build_args(network.forward, **x) | |||||
y = self._predict_func(**x) | |||||
x = _build_args(func, **x) | |||||
y = func(**x) | |||||
return y | return y | ||||
def format_eval_results(self, results): | |||||
def _format_eval_results(self, results): | |||||
"""Override this method to support more print formats. | """Override this method to support more print formats. | ||||
:param results: dict, (str: float) is (metrics name: value) | :param results: dict, (str: float) is (metrics name: value) | ||||
@@ -25,7 +25,7 @@ from fastNLP.core.losses import LossBase | |||||
from fastNLP.core.metrics import MetricBase | from fastNLP.core.metrics import MetricBase | ||||
from fastNLP.core.losses import _prepare_losser | from fastNLP.core.losses import _prepare_losser | ||||
from fastNLP.core.metrics import _prepare_metrics | from fastNLP.core.metrics import _prepare_metrics | ||||
from fastNLP.core.utils import CheckError | |||||
class Trainer(object): | class Trainer(object): | ||||
"""Main Training Loop | """Main Training Loop | ||||
@@ -211,13 +211,11 @@ class Trainer(object): | |||||
def get_loss(self, predict, truth): | def get_loss(self, predict, truth): | ||||
"""Compute loss given prediction and ground truth. | """Compute loss given prediction and ground truth. | ||||
:param predict: prediction label vector | |||||
:param truth: ground truth label vector | |||||
:param predict: prediction dict, produced by model.forward | |||||
:param truth: ground truth dict, produced by batch_y | |||||
:return: a scalar | :return: a scalar | ||||
""" | """ | ||||
assert isinstance(predict, dict) and isinstance(truth, dict) | |||||
args = _build_args(self.loss_func, **predict, **truth) | |||||
return self.loss_func(**args) | |||||
return self.losser(predict, truth) | |||||
def save_model(self, model, model_name, only_param=False): | def save_model(self, model, model_name, only_param=False): | ||||
model_name = os.path.join(self.save_path, model_name) | model_name = os.path.join(self.save_path, model_name) | ||||
@@ -260,11 +258,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
dev_data=None, | dev_data=None, | ||||
check_level=WARNING_CHECK_LEVEL): | check_level=WARNING_CHECK_LEVEL): | ||||
# check get_loss 方法 | # check get_loss 方法 | ||||
model_name = model.__class__.__name__ | |||||
model_devcie = model.parameters().__next__().device | |||||
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | ||||
for batch_count, (batch_x, batch_y) in enumerate(batch): | for batch_count, (batch_x, batch_y) in enumerate(batch): | ||||
_syn_model_data(model, batch_x, batch_y) | |||||
_move_dict_value_to_device(model_devcie, batch_x, batch_y) | |||||
# forward check | # forward check | ||||
if batch_count==0: | if batch_count==0: | ||||
_check_forward_error(model_func=model.forward, check_level=check_level, | _check_forward_error(model_func=model.forward, check_level=check_level, | ||||
@@ -277,68 +275,29 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.") | raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.") | ||||
# loss check | # loss check | ||||
if isinstance(losser, type): # 这种情况,用户传的是losser.CE这种未初始化的loss | |||||
# 需要保证output与batch_y是无歧义的? | |||||
# (1) output和batch_y长度为1 | |||||
# (2) output和batch_y的key是和losser接受的完全一致 | |||||
pass | |||||
loss = losser(output, batch_y) | |||||
try: | |||||
loss = losser(output, batch_y) | |||||
except CheckError as e: | |||||
_check_loss_evaluate(prev_func=model.forward, func=e.func_signature, | |||||
check_res=e.check_res, output=output, batch_y=batch_y, | |||||
check_level=check_level) | |||||
# check loss output | # check loss output | ||||
if batch_count == 0: | if batch_count == 0: | ||||
if not isinstance(loss, torch.Tensor): | if not isinstance(loss, torch.Tensor): | ||||
raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.". | |||||
format(type(losser), type(loss))) | |||||
raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, " | |||||
f"but got `{type(loss)}`.") | |||||
if len(loss.size())!=0: | if len(loss.size())!=0: | ||||
raise ValueError("The size of return value of {} is {}, should be torch.size([])".format( | |||||
type(losser), loss.size() | |||||
)) | |||||
raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, " | |||||
f"should be torch.size([])") | |||||
loss.backward() | loss.backward() | ||||
model.zero_grad() | model.zero_grad() | ||||
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: | if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: | ||||
break | break | ||||
if dev_data is not None: | if dev_data is not None: | ||||
outputs, truths = defaultdict(list), defaultdict(list) | |||||
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
# TODO 这里修改为使用tester | |||||
tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size, ) | |||||
with torch.no_grad(): | |||||
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): | |||||
_syn_model_data(model, batch_x, batch_y) | |||||
if hasattr(model, 'predict'): | |||||
if not callable(model.predict): | |||||
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used " | |||||
f"for evaluation.") | |||||
refined_batch_x = _build_args(model.predict, **batch_x) | |||||
prev_func = model.predict | |||||
output = prev_func(**refined_batch_x) | |||||
else: | |||||
refined_batch_x = _build_args(model.forward, **batch_x) | |||||
prev_func = model.forward | |||||
output = prev_func(**refined_batch_x) | |||||
func_signature = get_func_signature(prev_func) | |||||
if not isinstance(output, dict): | |||||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`") | |||||
for k, v in output.items(): | |||||
outputs[k].append(v) | |||||
for k, v in batch_y.items(): | |||||
truths[k].append(v) | |||||
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: | |||||
break | |||||
for k, v in outputs.items(): | |||||
outputs[k] = tuple(itertools.chain(*v)) | |||||
for k, v in truths.items(): | |||||
truths[k] = tuple(itertools.chain(*v)) | |||||
#TODO 这里需要根据新版的metrics做修改,另外这里需要捕获来自metric的报错,因为需要指导用户debug | |||||
tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||||
batch_size=batch_size, verbose=-1) | |||||
tester.test() | |||||
def _check_forward_error(model_func, check_level, batch_x): | def _check_forward_error(model_func, check_level, batch_x): | ||||
@@ -346,11 +305,11 @@ def _check_forward_error(model_func, check_level, batch_x): | |||||
_missing = '' | _missing = '' | ||||
_unused = '' | _unused = '' | ||||
func_signature = get_func_signature(model_func) | func_signature = get_func_signature(model_func) | ||||
if len(check_res.missing)!=0: | |||||
if len(check_res['missing'])!=0: | |||||
_missing = "Function {} misses {}, only provided with {}, " \ | _missing = "Function {} misses {}, only provided with {}, " \ | ||||
".\n".format(func_signature, check_res.missing, | ".\n".format(func_signature, check_res.missing, | ||||
list(batch_x.keys())) | list(batch_x.keys())) | ||||
if len(check_res.unused)!=0: | |||||
if len(check_res['unused'])!=0: | |||||
if len(check_res.unused) > 1: | if len(check_res.unused) > 1: | ||||
_unused = "{} are not used ".format(check_res.unused) | _unused = "{} are not used ".format(check_res.unused) | ||||
else: | else: | ||||
@@ -370,9 +329,7 @@ def _check_forward_error(model_func, check_level, batch_x): | |||||
elif check_level == WARNING_CHECK_LEVEL: | elif check_level == WARNING_CHECK_LEVEL: | ||||
warnings.warn(message=_unused) | warnings.warn(message=_unused) | ||||
def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): | |||||
check_res = _check_arg_dict_list(func, [output, batch_y]) | |||||
def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level): | |||||
_missing = '' | _missing = '' | ||||
_unused = '' | _unused = '' | ||||
_duplicated = '' | _duplicated = '' | ||||
@@ -220,13 +220,16 @@ class CheckError(Exception): | |||||
CheckError. Used in losses.LossBase, metrics.MetricBase. | CheckError. Used in losses.LossBase, metrics.MetricBase. | ||||
""" | """ | ||||
def __init__(self, check_res): | |||||
def __init__(self, check_res:CheckRes, func_signature:str): | |||||
err = '' | err = '' | ||||
if check_res['missing']: | |||||
err += f"Missing: {check_res['missing']}\n" | |||||
if check_res['duplicated']: | |||||
err += f"Duplicated: {check_res['duplicated']}\n" | |||||
if check_res['unused']: | |||||
err += f"Unused: {check_res['unused']}\n" | |||||
if check_res.missing: | |||||
err += f"Missing: {check_res.missing}\n" | |||||
if check_res.duplicated: | |||||
err += f"Duplicated: {check_res.duplicated}\n" | |||||
if check_res.unused: | |||||
err += f"Unused: {check_res.unused}\n" | |||||
Exception.__init__(self, err) | Exception.__init__(self, err) | ||||
self.check_res = check_res | self.check_res = check_res | ||||
self.func_signature = func_signature |