@@ -1,6 +1,7 @@
 import warnings
 import inspect
+from collections import defaultdict
 import numpy as np
 import torch
@@ -21,6 +22,7 @@ class MetricBase(object):
     def _init_param_map(self, key_map, **kwargs):
         self.param_map = {}
+        value_counter = defaultdict(int)
         for key, value in key_map.items():
             if not isinstance(key, str):
                 raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
@@ -32,16 +34,19 @@ class MetricBase(object):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
             self.param_map[key] = value

-    def __call__(self, output_dict, target_dict, force_check=False):
+    def __call__(self, output_dict, target_dict, check=False):
         """
         :param output_dict:
         :param target_dict:
+        :param check: boolean, if True, re-run the signature/param_map checks even if they have passed before.
         :return:
         """
         if not callable(self.evaluate):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

         if not self._checked:
+            # 0. check that param_map does not map two keys to the same value
             # 1. check consistency between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
             func_args = func_spect.args
@@ -65,7 +70,7 @@ class MetricBase(object):
                 mapped_target_dict[func_arg] = target_dict[input_arg]

         # check duplicated, unused, missing
-        if force_check or not self._checked:
+        if check or not self._checked:
             check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
             self._reverse_param_map = {value: key for key, value in self.param_map.items()}
             for key, value in check_res._asdict().items():
@@ -73,8 +78,9 @@ class MetricBase(object):
                 for idx, func_param in enumerate(value):
                     if func_param in self._reverse_param_map:
                         new_value[idx] = self._reverse_param_map[func_param]
-            if check_res.missing or check_res.duplicated:
-                raise CheckError(check_res=check_res)
+            if check_res.missing or check_res.duplicated or check_res.varargs:
+                raise CheckError(check_res=check_res,
+                                 func_signature=get_func_signature(self.evaluate))
         refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)
         metrics = self.evaluate(**refined_args)
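
For illustration (editorial, hypothetical names): the strengthened condition now also surfaces `*args` problems, and the raised `CheckError` carries the signature of `evaluate` for better reporting:

    import torch
    preds = torch.tensor([1, 0, 1])
    golds = torch.tensor([1, 1, 0])

    metric = AccMetric()   # the sketch class from above
    try:
        # the DataSet provides "gold", but param_map expects "label" -> missing
        metric(output_dict={"output": preds}, target_dict={"gold": golds})
    except CheckError as e:
        print(e.func_signature)    # e.g. "AccMetric.evaluate(self, pred, target)"
        print(e.check_res.missing)
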
@@ -92,7 +98,6 @@ class Metric(MetricBase):
         super().__init__()
-        pass

 def _prepare_metrics(metrics):
     """
@@ -12,6 +12,7 @@ from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _check_loss_evaluate

 class Tester(object):
     """A collection of model inference and performance evaluation, used on the validation/dev set and the test set."""
@@ -47,7 +48,6 @@ class Tester(object):
         self._model_device = model.parameters().__next__().device
-
     def test(self):
         # turn on the testing mode; clean up the history
         network = self._model
@@ -75,7 +75,9 @@ class Tester(object):
                     metric_name = metric.__class__.__name__
                     eval_results[metric_name] = eval_result
             except CheckError as e:
-                pass
+                prev_func_signature = get_func_signature(self._predict_func)
+                _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
+                                     check_res=e.check_res, output=output, batch_y=truths)

         if self.verbose >= 0:
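
An editorial sketch (all names hypothetical, assuming `CheckRes` and `_check_loss_evaluate` are imported from `fastNLP.core.utils`) of what the handler now surfaces instead of silently passing:

    check_res = CheckRes(missing=['target'], unused=[], duplicated=[],
                         required=['pred', 'target'], all_needed=['pred', 'target'],
                         varargs=[])
    _check_loss_evaluate(prev_func_signature="Model.predict(self, words)",
                         func_signature="AccMetric.evaluate(self, pred, target)",
                         check_res=check_res,
                         output={"pred": None}, batch_y={"gold": None})
    # -> NameError:
    #    The following problems occurred when calling AccMetric.evaluate(self, pred, target)
    #        missing param: ['target'], only provided with ['pred'] (from
    #        Model.predict(self, words)) and ['gold'] (from targets in Dataset).
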
@@ -20,12 +20,11 @@ from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import LossBase
-from fastNLP.core.metrics import MetricBase
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _check_loss_evaluate
+from fastNLP.core.utils import _check_forward_error

 class Trainer(object):
     """Main Training Loop
@@ -33,7 +32,7 @@ class Trainer(object):
     """
     def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
+                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), check_code_level=0,
                  **kwargs):
         super(Trainer, self).__init__()
@@ -53,8 +52,9 @@ class Trainer(object):
         # prepare loss
         losser = _prepare_losser(losser)

-        if need_check_code:
-            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)
+        if check_code_level > -1:
+            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
+                        check_level=check_code_level)

         self.train_data = train_data
         self.dev_data = dev_data  # If None, no validation.
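
An editorial sketch of the new knob (model and data names hypothetical): the old boolean becomes a level, where check_code_level=-1 disables the pre-training check entirely and 0/1/2 map to the IGNORE/WARNING/STRICT constants added to utils.py later in this patch:

    trainer = Trainer(train_data=train_set, model=model,
                      losser=loss, metrics=metrics, dev_data=dev_set,
                      check_code_level=2)   # STRICT: unused fields raise instead of warn
    # check_code_level=-1 skips _check_code entirely
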
@@ -250,13 +250,9 @@ class Trainer(object):
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2

-IGNORE_CHECK_LEVEL = 0
-WARNING_CHECK_LEVEL = 1
-STRICT_CHECK_LEVEL = 2

 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
-                check_level=WARNING_CHECK_LEVEL):
+                check_level=0):
     # check the get_loss method
     model_device = model.parameters().__next__().device
@@ -265,7 +261,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         _move_dict_value_to_device(model_device, batch_x, batch_y)
         # forward check
         if batch_count == 0:
-            _check_forward_error(model_func=model.forward, check_level=check_level,
+            _check_forward_error(forward_func=model.forward, check_level=check_level,
                                  batch_x=batch_x)

         refined_batch_x = _build_args(model.forward, **batch_x)
@@ -277,19 +273,21 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         # loss check
         try:
             loss = losser(output, batch_y)
+            # check loss output
+            if batch_count == 0:
+                if not isinstance(loss, torch.Tensor):
+                    raise TypeError(
+                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
+                        f"but got `{type(loss)}`.")
+                if len(loss.size()) != 0:
+                    raise ValueError(
+                        f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
+                        f"should be torch.Size([])")
+            loss.backward()
         except CheckError as e:
             _check_loss_evaluate(prev_func_signature=get_func_signature(model.forward),
                                  func_signature=e.func_signature,
                                  check_res=e.check_res, output=output, batch_y=batch_y,
                                  check_level=check_level)
-        # check loss output
-        if batch_count == 0:
-            if not isinstance(loss, torch.Tensor):
-                raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
-                                f"but got `{type(loss)}`.")
-            if len(loss.size())!=0:
-                raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
-                                 f"should be torch.size([])")
-        loss.backward()
         model.zero_grad()
         if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break
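
An editorial sketch of the invariant enforced by the relocated check (pure PyTorch, independent of fastNLP): the loss must be a 0-dimensional tensor so that `loss.backward()` can be called without arguments:

    import torch
    import torch.nn.functional as F

    pred = torch.randn(4, 3, requires_grad=True)
    target = torch.tensor([0, 2, 1, 0])

    loss = F.cross_entropy(pred, target)          # reduction='mean' -> scalar
    assert loss.size() == torch.Size([])          # passes the size check
    loss.backward()

    unreduced = F.cross_entropy(pred, target, reduction='none')
    assert unreduced.size() == torch.Size([4])    # would trigger the ValueError above
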
@@ -300,93 +298,5 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     tester.test()

-def _check_forward_error(model_func, check_level, batch_x):
-    check_res = _check_arg_dict_list(model_func, batch_x)
-    _missing = ''
-    _unused = ''
-    func_signature = get_func_signature(model_func)
-    if len(check_res['missing'])!=0:
-        _missing = "Function {} misses {}, only provided with {}, " \
-                   ".\n".format(func_signature, check_res.missing,
-                                list(batch_x.keys()))
-    if len(check_res['unused'])!=0:
-        if len(check_res.unused) > 1:
-            _unused = "{} are not used ".format(check_res.unused)
-        else:
-            _unused = "{} is not used ".format(check_res.unused)
-        _unused += "in function {}.\n".format(func_signature)
-    if _missing:
-        if len(_unused)>0 and STRICT_CHECK_LEVEL:
-            _error_str = "(1).{}\n(2).{}".format(_missing, _unused)
-        else:
-            _error_str = _missing
-        # TODO we may need some custom Error types here
-        raise TypeError(_error_str)
-    if _unused:
-        if check_level == STRICT_CHECK_LEVEL:
-            # TODO we may need some custom Error types here
-            raise ValueError(_unused)
-        elif check_level == WARNING_CHECK_LEVEL:
-            warnings.warn(message=_unused)

-def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
-    _missing = ''
-    _unused = ''
-    _duplicated = ''
-    func_signature = get_func_signature(func)
-    prev_func_signature = get_func_signature(prev_func)
-    if len(check_res.missing)>0:
-        _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
-                   "{}(from target in Dataset)." \
-            .format(func_signature, check_res.missing,
-                    list(output.keys()), prev_func_signature,
-                    list(batch_y.keys()))
-    if len(check_res.unused)>0:
-        if len(check_res.unused) > 1:
-            _unused = "{} are not used ".format(check_res.unused)
-        else:
-            _unused = "{} is not used ".format(check_res.unused)
-        _unused += "in function {}.\n".format(func_signature)
-    if len(check_res.duplicated)>0:
-        if len(check_res.duplicated) > 1:
-            _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
-                          "them in {} at the same time.".format(check_res.duplicated,
-                                                                func_signature,
-                                                                check_res.duplicated,
-                                                                prev_func_signature)
-        else:
-            _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
-                          "it in {} at the same time.".format(check_res.duplicated,
-                                                              func_signature,
-                                                              check_res.duplicated,
-                                                              prev_func_signature)
-    _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
-    if _number_errs > 0:
-        _error_strs = []
-        if _number_errs > 1:
-            count = 0
-            order_words = ['Firstly', 'Secondly', 'Thirdly']
-            if _missing:
-                _error_strs.append('{}, {}'.format(order_words[count], _missing))
-                count += 1
-            if _duplicated:
-                _error_strs.append('{}, {}'.format(order_words[count], _duplicated))
-                count += 1
-            if _unused and check_level == STRICT_CHECK_LEVEL:
-                _error_strs.append('{}, {}'.format(order_words[count], _unused))
-        else:
-            if _unused:
-                if check_level == STRICT_CHECK_LEVEL:
-                    # TODO we may need some custom Error types here
-                    _error_strs.append(_unused)
-                elif check_level == WARNING_CHECK_LEVEL:
-                    _unused = _unused.strip()
-                    warnings.warn(_unused)
-            else:
-                if _missing:
-                    _error_strs.append(_missing)
-                if _duplicated:
-                    _error_strs.append(_duplicated)
-        if _error_strs:
-            raise ValueError('\n' + '\n'.join(_error_strs))
@@ -1,11 +1,14 @@
 import _pickle
 import inspect
 import os
+import warnings
 from collections import Counter
 from collections import namedtuple

 import torch

-CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
+CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
+                                   'varargs'], verbose=False)

 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
@@ -105,7 +108,6 @@ def _check_arg_dict_list(func, args):
     assert callable(func) and isinstance(arg_dict_list, (list, tuple))
     assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
     spect = inspect.getfullargspec(func)
-    assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
     all_args = set([arg for arg in spect.args if arg != 'self'])
     defaults = []
     if spect.defaults is not None:
@@ -125,7 +127,8 @@ def _check_arg_dict_list(func, args):
                     unused=unused,
                     duplicated=duplicated,
                     required=list(require_args),
-                    all_needed=list(all_args))
+                    all_needed=list(all_args),
+                    varargs=[spect.varargs] if spect.varargs else [])

 def get_func_signature(func):
     """
@@ -221,15 +224,73 @@ class CheckError(Exception):
     CheckError. Used in losses.LossBase, metrics.MetricBase.
     """
     def __init__(self, check_res: CheckRes, func_signature: str):
-        err = ''
+        errs = [f'The following problems occurred when calling {func_signature}']
+        if check_res.varargs:
+            errs.append(f"\tvarargs: {check_res.varargs} (passing positional *args is not supported, please remove them)")
         if check_res.missing:
-            err += f"Missing: {check_res.missing}\n"
+            errs.append(f"\tmissing param: {check_res.missing}")
         if check_res.duplicated:
-            err += f"Duplicated: {check_res.duplicated}\n"
+            errs.append(f"\tduplicated param: {check_res.duplicated}")
         if check_res.unused:
-            err += f"Unused: {check_res.unused}\n"
-        Exception.__init__(self, err)
+            errs.append(f"\tunused param: {check_res.unused}")
+        Exception.__init__(self, '\n'.join(errs))
         self.check_res = check_res
         self.func_signature = func_signature
+IGNORE_CHECK_LEVEL = 0
+WARNING_CHECK_LEVEL = 1
+STRICT_CHECK_LEVEL = 2
+
+def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
+                         output: dict, batch_y: dict, check_level=0):
+    errs = []
+    _unused = []
+    if check_res.varargs:
+        errs.append(f"\tvarargs: {check_res.varargs} (passing positional *args is not supported, "
+                    f"please remove them.)")
+    if check_res.missing:
+        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())} "
+                    f"(from {prev_func_signature}) and {list(batch_y.keys())} (from targets in Dataset).")
+    if check_res.duplicated:
+        errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} from the output of "
+                    f"{prev_func_signature} or do not set {check_res.duplicated} as targets.")
+    if check_res.unused:
+        _unused = [f"\tunused param: {check_res.unused}"]
+        if check_level == STRICT_CHECK_LEVEL:
+            errs.extend(_unused)
+
+    if len(errs) > 0:
+        errs.insert(0, f'The following problems occurred when calling {func_signature}')
+        raise NameError('\n'.join(errs))
+    if _unused:
+        if check_level == WARNING_CHECK_LEVEL:
+            _unused_warn = _unused[0] + f' in {func_signature}.'
+            warnings.warn(message=_unused_warn)
+def _check_forward_error(forward_func, batch_x, check_level):
+    check_res = _check_arg_dict_list(forward_func, batch_x)
+    func_signature = get_func_signature(forward_func)
+    errs = []
+    _unused = []
+    if check_res.varargs:
+        errs.append(f"\tvarargs: {check_res.varargs} (passing positional *args is not supported, please remove them)")
+    if check_res.missing:
+        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
+    if check_res.unused:
+        _unused = [f"\tunused param: {check_res.unused}"]
+        if check_level == STRICT_CHECK_LEVEL:
+            errs.extend(_unused)
+    if len(errs) > 0:
+        errs.insert(0, f'The following problems occurred when calling {func_signature}')
+        raise NameError('\n'.join(errs))
+    if _unused:
+        if check_level == WARNING_CHECK_LEVEL:
+            _unused_warn = _unused[0] + f' in {func_signature}.'
+            warnings.warn(message=_unused_warn)
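
A closing editorial sketch (toy forward function and batch, hypothetical; assumes the names above are imported from `fastNLP.core.utils`) of how the three check levels interact with the new `_check_forward_error`:

    import warnings

    def forward(words, seq_len):
        return {"pred": None}

    batch_x = {"words": [[1, 2]], "seq_len": [2], "pos": [[3, 4]]}  # `pos` is unused

    _check_forward_error(forward, batch_x, check_level=IGNORE_CHECK_LEVEL)    # silent
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _check_forward_error(forward, batch_x, check_level=WARNING_CHECK_LEVEL)
        assert "unused param" in str(caught[0].message)
    try:
        _check_forward_error(forward, batch_x, check_level=STRICT_CHECK_LEVEL)
    except NameError as err:
        print(err)   # lists the unused `pos` field
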