@@ -1,6 +1,7 @@
import warnings
import inspect
from collections import defaultdict

import numpy as np
import torch
@@ -21,6 +22,7 @@ class MetricBase(object):
    def _init_param_map(self, key_map, **kwargs):
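        # key_map maps each argument name of `evaluate` onto the field name it
        # should be read from in the output/target dicts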
        self.param_map = {}
        value_counter = defaultdict(int)
        for key, value in key_map.items():
            if not isinstance(key, str):
                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
@@ -32,16 +34,19 @@ class MetricBase(object):
                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
            self.param_map[key] = value

    def __call__(self, output_dict, target_dict, force_check=False):
    def __call__(self, output_dict, target_dict, check=False):
""" | |||
:param output_dict: | |||
:param target_dict: | |||
:param check: boolean, | |||
:return: | |||
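
        Example::

            # a minimal usage sketch; `MyMetric` and the field names are
            # hypothetical, not part of this diff
            metric = MyMetric()  # a MetricBase subclass defining evaluate(pred, target)
            results = metric(output_dict={'pred': pred_tensor},
                             target_dict={'target': target_tensor})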
""" | |||
        if not callable(self.evaluate):
            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

        if not self._checked:
            # 0. check that param_map does not map two arguments onto the same value
            # 1. check consistency between the signature and param_map
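            # grab the full signature of `evaluate`; every key in param_map
            # must be a real argument name of that signature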
            func_spect = inspect.getfullargspec(self.evaluate)
            func_args = func_spect.args
@@ -65,7 +70,7 @@ class MetricBase(object):
                mapped_target_dict[func_arg] = target_dict[input_arg]

        # check duplicated, unused, missing
        if force_check or not self._checked:
        if check or not self._checked:
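            # _check_arg_dict_list reports the missing/unused/duplicated
            # arguments of `evaluate` given the two mapped dicts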
            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
            self._reverse_param_map = {value: key for key, value in self.param_map.items()}
            for key, value in check_res._asdict().items():
@@ -73,8 +78,9 @@ class MetricBase(object):
                new_value = list(value)
                for idx, func_param in enumerate(value):
                    if func_param in self._reverse_param_map:
                        new_value[idx] = self._reverse_param_map[func_param]
            if check_res.missing or check_res.duplicated:
                raise CheckError(check_res=check_res)
            if check_res.missing or check_res.duplicated or check_res.varargs:
                raise CheckError(check_res=check_res,
                                 func_signature=get_func_signature(self.evaluate))
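        # _build_args keeps only the keyword arguments that `evaluate` actually
        # accepts, so extra fields in the dicts do not break the call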
        refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)

        metrics = self.evaluate(**refined_args)
@@ -92,7 +98,6 @@ class Metric(MetricBase):
        super().__init__()
        pass

def _prepare_metrics(metrics):
    """
@@ -12,6 +12,7 @@ from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate

class Tester(object):
    """A collection of model inference and performance evaluation routines, used on the validation/dev set and the test set."""
@@ -47,7 +48,6 @@ class Tester(object):
        self._model_device = model.parameters().__next__().device

    def test(self):
        # turn on the testing mode; clean up the history
        network = self._model
@@ -75,7 +75,9 @@ class Tester(object):
                    metric_name = metric.__class__.__name__
                    eval_results[metric_name] = eval_result
        except CheckError as e:
            pass
            prev_func_signature = get_func_signature(self._predict_func)
            _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, output=output, batch_y=truths)

        if self.verbose >= 0:
@@ -20,12 +20,11 @@ from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _check_forward_error

class Trainer(object):
    """Main Training Loop
@@ -33,7 +32,7 @@ class Trainer(object):
    """
    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
                 dev_data=None, use_cuda=False, save_path="./save",
                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), check_code_level=0,
                 **kwargs):
        super(Trainer, self).__init__()
@@ -53,8 +52,9 @@ class Trainer(object):
        # prepare loss
        losser = _prepare_losser(losser)

        if need_check_code:
            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)
        if check_code_level > -1:
            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                        check_level=check_code_level)

        self.train_data = train_data
        self.dev_data = dev_data  # If None, no validation is performed.
@@ -250,13 +250,9 @@ class Trainer(object):
DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2
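# the code check deliberately runs only a couple of tiny batches, so that
# signature mismatches surface before real training starts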
IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2

def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                dev_data=None,
                check_level=WARNING_CHECK_LEVEL):
                check_level=0):
    # check the get_loss method
    model_device = model.parameters().__next__().device
@@ -265,7 +261,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
        _move_dict_value_to_device(model_device, batch_x, batch_y)

        # forward check
        if batch_count == 0:
            _check_forward_error(model_func=model.forward, check_level=check_level,
            _check_forward_error(forward_func=model.forward, check_level=check_level,
                                 batch_x=batch_x)

        refined_batch_x = _build_args(model.forward, **batch_x)
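        # like the metric path above, _build_args filters batch_x down to the
        # keyword arguments that model.forward actually declares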
@@ -277,19 +273,21 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
        # loss check
        try:
            loss = losser(output, batch_y)
            # check loss output
            if batch_count == 0:
                if not isinstance(loss, torch.Tensor):
                    raise TypeError(
                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
                        f"but got `{type(loss)}`.")
                if len(loss.size()) != 0:
                    raise ValueError(
                        f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
                        f"should be torch.Size([])")
            loss.backward()
        except CheckError as e:
            _check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
                                 check_res=e.check_res, output=output, batch_y=batch_y,
                                 check_level=check_level)
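        # the loss must be a scalar (0-dim) tensor; anything else cannot call
        # backward() without an explicit gradient argument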
        # check loss output
        if batch_count == 0:
            if not isinstance(loss, torch.Tensor):
                raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
                                f"but got `{type(loss)}`.")
            if len(loss.size()) != 0:
                raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
                                 f"should be torch.Size([])")
        loss.backward()
        model.zero_grad()
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break
@@ -300,93 +298,5 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
        tester.test()
def _check_forward_error(model_func, check_level, batch_x):
    check_res = _check_arg_dict_list(model_func, batch_x)
    _missing = ''
    _unused = ''
    func_signature = get_func_signature(model_func)
    if len(check_res['missing']) != 0:
        _missing = "Function {} misses {}, only provided with {}.\n".format(
            func_signature, check_res.missing, list(batch_x.keys()))
    if len(check_res['unused']) != 0:
        if len(check_res.unused) > 1:
            _unused = "{} are not used ".format(check_res.unused)
        else:
            _unused = "{} is not used ".format(check_res.unused)
        _unused += "in function {}.\n".format(func_signature)
    if _missing:
        if len(_unused) > 0 and check_level == STRICT_CHECK_LEVEL:
            _error_str = "(1).{}\n(2).{}".format(_missing, _unused)
        else:
            _error_str = _missing
        # TODO: custom Error types may be needed here
        raise TypeError(_error_str)
    if _unused:
        if check_level == STRICT_CHECK_LEVEL:
            # TODO: custom Error types may be needed here
            raise ValueError(_unused)
        elif check_level == WARNING_CHECK_LEVEL:
            warnings.warn(message=_unused)
def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
    _missing = ''
    _unused = ''
    _duplicated = ''
    func_signature = get_func_signature(func)
    prev_func_signature = get_func_signature(prev_func)
    if len(check_res.missing) > 0:
        _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
                   "{}(from target in Dataset)." \
            .format(func_signature, check_res.missing,
                    list(output.keys()), prev_func_signature,
                    list(batch_y.keys()))
    if len(check_res.unused) > 0:
        if len(check_res.unused) > 1:
            _unused = "{} are not used ".format(check_res.unused)
        else:
            _unused = "{} is not used ".format(check_res.unused)
        _unused += "in function {}.\n".format(func_signature)
    if len(check_res.duplicated) > 0:
        if len(check_res.duplicated) > 1:
            _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
                          "them in {} at the same time.".format(check_res.duplicated,
                                                                func_signature,
                                                                check_res.duplicated,
                                                                prev_func_signature)
        else:
            _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
                          "it in {} at the same time.".format(check_res.duplicated,
                                                              func_signature,
                                                              check_res.duplicated,
                                                              prev_func_signature)
    _number_errs = int(len(_missing) != 0) + int(len(_duplicated) != 0) + int(len(_unused) != 0)
    if _number_errs > 0:
        _error_strs = []
        if _number_errs > 1:
            count = 0
            order_words = ['Firstly', 'Secondly', 'Thirdly']
            if _missing:
                _error_strs.append('{}, {}'.format(order_words[count], _missing))
                count += 1
            if _duplicated:
                _error_strs.append('{}, {}'.format(order_words[count], _duplicated))
                count += 1
            if _unused and check_level == STRICT_CHECK_LEVEL:
                _error_strs.append('{}, {}'.format(order_words[count], _unused))
        else:
            if _unused:
                if check_level == STRICT_CHECK_LEVEL:
                    # TODO: custom Error types may be needed here
                    _error_strs.append(_unused)
                elif check_level == WARNING_CHECK_LEVEL:
                    _unused = _unused.strip()
                    warnings.warn(_unused)
            else:
                if _missing:
                    _error_strs.append(_missing)
                if _duplicated:
                    _error_strs.append(_duplicated)
        if _error_strs:
            raise ValueError('\n' + '\n'.join(_error_strs))
@@ -1,11 +1,14 @@
import _pickle
import inspect
import os
import warnings
from collections import Counter
from collections import namedtuple

import torch

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                   'varargs'], verbose=False)
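# CheckRes records, for one checked call, which arguments are missing, unused,
# or duplicated, plus the required/all-needed argument lists and any *args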
def save_pickle(obj, pickle_path, file_name):
    """Save an object into a pickle file.
@@ -105,7 +108,6 @@ def _check_arg_dict_list(func, args):
    assert callable(func) and isinstance(arg_dict_list, (list, tuple))
    assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
    spect = inspect.getfullargspec(func)
    assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
    all_args = set([arg for arg in spect.args if arg != 'self'])
    defaults = []
    if spect.defaults is not None:
@@ -125,7 +127,8 @@ def _check_arg_dict_list(func, args):
                    unused=unused,
                    duplicated=duplicated,
                    required=list(require_args),
                    all_needed=list(all_args))
                    all_needed=list(all_args),
                    varargs=[spect.varargs] if spect.varargs else [])  # spect.varargs is a name or None, not a list
def get_func_signature(func):
    """
@@ -221,15 +224,73 @@ class CheckError(Exception):
    CheckError. Used in losses.LossBase, metrics.MetricBase.
    """
    def __init__(self, check_res: CheckRes, func_signature: str):
        err = ''
        errs = [f'The following problems occurred when calling {func_signature}']

        if check_res.varargs:
            errs.append(f"\tvarargs: {check_res.varargs} (passing positional arguments via *args is not supported, "
                        f"please remove it)")
        if check_res.missing:
            err += f"Missing: {check_res.missing}\n"
            errs.append(f"\tmissing param: {check_res.missing}")
        if check_res.duplicated:
            err += f"Duplicated: {check_res.duplicated}\n"
            errs.append(f"\tduplicated param: {check_res.duplicated}")
        if check_res.unused:
            err += f"Unused: {check_res.unused}\n"
            errs.append(f"\tunused param: {check_res.unused}")
        Exception.__init__(self, err)
        Exception.__init__(self, '\n'.join(errs))

        self.check_res = check_res
        self.func_signature = func_signature
IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2
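# check levels: IGNORE drops unused-argument problems, WARNING reports them as
# warnings, STRICT raises; missing or duplicated arguments always raise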
def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
                         output: dict, batch_y: dict, check_level=0):
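    # gather every detected problem first, so the final exception reports all
    # of them at once instead of failing on the first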
    errs = []
    _unused = []
    if check_res.varargs:
        errs.append(f"\tvarargs: {check_res.varargs} (passing positional arguments via *args is not supported, "
                    f"please remove it)")
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())}"
                    f"(from {prev_func_signature}) and {list(batch_y.keys())}(from targets in Dataset).")
    if check_res.duplicated:
        errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
                    f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
    if check_res.unused:
        _unused = [f"\tunused param: {check_res.unused}"]
        if check_level == STRICT_CHECK_LEVEL:
            errs.extend(_unused)

    if len(errs) > 0:
        errs.insert(0, f'The following problems occurred when calling {func_signature}')
        raise NameError('\n'.join(errs))
    if _unused:
        if check_level == WARNING_CHECK_LEVEL:
            _unused_warn = _unused[0] + f' in {func_signature}.'
            warnings.warn(message=_unused_warn)
def _check_forward_error(forward_func, batch_x, check_level):
    check_res = _check_arg_dict_list(forward_func, batch_x)
    func_signature = get_func_signature(forward_func)
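    # unlike _check_loss_evaluate, the forward check draws only on batch_x, so
    # its messages do not reference a preceding function's output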
    errs = []
    _unused = []
    if check_res.varargs:
        errs.append(f"\tvarargs: {check_res.varargs} (passing positional arguments via *args is not supported, "
                    f"please remove it)")
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
    if check_res.unused:
        _unused = [f"\tunused param: {check_res.unused}"]
        if check_level == STRICT_CHECK_LEVEL:
            errs.extend(_unused)

    if len(errs) > 0:
        errs.insert(0, f'The following problems occurred when calling {func_signature}')
        raise NameError('\n'.join(errs))
    if _unused:
        if check_level == WARNING_CHECK_LEVEL:
            _unused_warn = _unused[0] + f' in {func_signature}.'
            warnings.warn(message=_unused_warn)