diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 60e0d82f..69bb540d 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -1,6 +1,7 @@ import warnings import inspect +from collections import defaultdict import numpy as np import torch @@ -21,6 +22,7 @@ class MetricBase(object): def _init_param_map(self, key_map, **kwargs): self.param_map = {} + value_counter = defaultdict(int) for key, value in key_map.items(): if isinstance(key, str): raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") @@ -32,16 +34,19 @@ class MetricBase(object): raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") self.param_map[key] = value - def __call__(self, output_dict, target_dict, force_check=False): + def __call__(self, output_dict, target_dict, check=False): """ :param output_dict: :param target_dict: + :param check: boolean, :return: """ if not callable(self.evaluate): raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") if not self._checked: + # 0. check param_map does not have same value + # 1. 
check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) func_args = func_spect.args @@ -65,7 +70,7 @@ class MetricBase(object): mapped_target_dict[func_arg] = target_dict[input_arg] # check duplicated, unused, missing - if force_check or not self._checked: + if check or not self._checked: check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict]) self._reverse_param_map = {value:key for key, value in check_res.items()} for key, value in check_res.items(): @@ -73,8 +78,9 @@ class MetricBase(object): for idx, func_param in enumerate(value): if func_param in self._reverse_param_map: new_value[idx] = self._reverse_param_map[func_param] - if check_res.missing or check_res.duplicated: - raise CheckError(check_res=check_res) + if check_res.missing or check_res.duplicated or check_res.varargs: + raise CheckError(check_res=check_res, + func_signature=get_func_signature(self.evaluate)) refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict) metrics = self.evaluate(**refined_args) @@ -92,7 +98,6 @@ class Metric(MetricBase): super().__init__() pass - def _prepare_metrics(metrics): """ diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 33d8cc81..39efb454 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -12,6 +12,7 @@ from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.metrics import _prepare_metrics from fastNLP.core.utils import CheckError +from fastNLP.core.utils import _check_loss_evaluate class Tester(object): """An collection of model inference and evaluation of performance, used over validation/dev set and test set. 
""" @@ -47,7 +48,6 @@ class Tester(object): self._model_device = model.parameters().__next__().device - def test(self): # turn on the testing mode; clean up the history network = self._model @@ -75,7 +75,9 @@ class Tester(object): metric_name = metric.__class__.__name__ eval_results[metric_name] = eval_result except CheckError as e: - pass + prev_func_signature = get_func_signature(self._predict_func) + _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, + check_res=e.check_res, output=output, batch_y=truths) if self.verbose >= 0: diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index da8e54f9..acbcb586 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -20,12 +20,11 @@ from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.utils import get_func_signature from fastNLP.core.dataset import DataSet - -from fastNLP.core.losses import LossBase -from fastNLP.core.metrics import MetricBase from fastNLP.core.losses import _prepare_losser from fastNLP.core.metrics import _prepare_metrics from fastNLP.core.utils import CheckError +from fastNLP.core.utils import _check_loss_evaluate +from fastNLP.core.utils import _check_forward_error class Trainer(object): """Main Training Loop @@ -33,7 +32,7 @@ class Trainer(object): """ def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, dev_data=None, use_cuda=False, save_path="./save", - optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, + optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), check_code_level=0, **kwargs): super(Trainer, self).__init__() @@ -53,8 +52,9 @@ class Trainer(object): # prepare loss losser = _prepare_losser(losser) - if need_check_code: - _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data) + if check_code_level>-1: + 
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, + check_level=check_code_level) self.train_data = train_data self.dev_data = dev_data # If None, No validation. @@ -250,13 +250,9 @@ class Trainer(object): DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_NUM_BATCH = 2 -IGNORE_CHECK_LEVEL = 0 -WARNING_CHECK_LEVEL = 1 -STRICT_CHECK_LEVEL = 2 - def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, - check_level=WARNING_CHECK_LEVEL): + check_level=0): # check get_loss 方法 model_devcie = model.parameters().__next__().device @@ -265,7 +261,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ _move_dict_value_to_device(model_devcie, batch_x, batch_y) # forward check if batch_count==0: - _check_forward_error(model_func=model.forward, check_level=check_level, + _check_forward_error(forward_func=model.forward, check_level=check_level, batch_x=batch_x) refined_batch_x = _build_args(model.forward, **batch_x) @@ -277,19 +273,21 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ # loss check try: loss = losser(output, batch_y) + # check loss output + if batch_count == 0: + if not isinstance(loss, torch.Tensor): + raise TypeError( + f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, " + f"but got `{type(loss)}`.") + if len(loss.size()) != 0: + raise ValueError( + f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, " + f"should be torch.size([])") + loss.backward() except CheckError as e: _check_loss_evaluate(prev_func=model.forward, func=e.func_signature, check_res=e.check_res, output=output, batch_y=batch_y, check_level=check_level) - # check loss output - if batch_count == 0: - if not isinstance(loss, torch.Tensor): - raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, " - f"but got `{type(loss)}`.") - if 
len(loss.size())!=0: - raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, " - f"should be torch.size([])") - loss.backward() model.zero_grad() if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: break @@ -300,93 +298,5 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ tester.test() -def _check_forward_error(model_func, check_level, batch_x): - check_res = _check_arg_dict_list(model_func, batch_x) - _missing = '' - _unused = '' - func_signature = get_func_signature(model_func) - if len(check_res['missing'])!=0: - _missing = "Function {} misses {}, only provided with {}, " \ - ".\n".format(func_signature, check_res.missing, - list(batch_x.keys())) - if len(check_res['unused'])!=0: - if len(check_res.unused) > 1: - _unused = "{} are not used ".format(check_res.unused) - else: - _unused = "{} is not used ".format(check_res.unused) - _unused += "in function {}.\n".format(func_signature) - if _missing: - if len(_unused)>0 and STRICT_CHECK_LEVEL: - _error_str = "(1).{}\n(2).{}".format(_missing, _unused) - else: - _error_str = _missing - # TODO 这里可能需要自定义一些Error类型 - raise TypeError(_error_str) - if _unused: - if check_level == STRICT_CHECK_LEVEL: - # TODO 这里可能需要自定义一些Error类型 - raise ValueError(_unused) - elif check_level == WARNING_CHECK_LEVEL: - warnings.warn(message=_unused) - -def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level): - _missing = '' - _unused = '' - _duplicated = '' - func_signature = get_func_signature(func) - prev_func_signature = get_func_signature(prev_func) - if len(check_res.missing)>0: - _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \ - "{}(from target in Dataset)." 
\ - .format(func_signature, check_res.missing, - list(output.keys()), prev_func_signature, - list(batch_y.keys())) - if len(check_res.unused)>0: - if len(check_res.unused) > 1: - _unused = "{} are not used ".format(check_res.unused) - else: - _unused = "{} is not used ".format(check_res.unused) - _unused += "in function {}.\n".format(func_signature) - if len(check_res.duplicated)>0: - if len(check_res.duplicated) > 1: - _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \ - "them in {} at the same time.".format(check_res.duplicated, - func_signature, - check_res.duplicated, - prev_func_signature) - else: - _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \ - "it in {} at the same time.".format(check_res.duplicated, - func_signature, - check_res.duplicated, - prev_func_signature) - _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) - if _number_errs > 0: - _error_strs = [] - if _number_errs > 1: - count = 0 - order_words = ['Firstly', 'Secondly', 'Thirdly'] - if _missing: - _error_strs.append('{}, {}'.format(order_words[count], _missing)) - count += 1 - if _duplicated: - _error_strs.append('{}, {}'.format(order_words[count], _duplicated)) - count += 1 - if _unused and check_level == STRICT_CHECK_LEVEL: - _error_strs.append('{}, {}'.format(order_words[count], _unused)) - else: - if _unused: - if check_level == STRICT_CHECK_LEVEL: - # TODO 这里可能需要自定义一些Error类型 - _error_strs.append(_unused) - elif check_level == WARNING_CHECK_LEVEL: - _unused = _unused.strip() - warnings.warn(_unused) - else: - if _missing: - _error_strs.append(_missing) - if _duplicated: - _error_strs.append(_duplicated) - if _error_strs: - raise ValueError('\n' + '\n'.join(_error_strs)) + diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 61c5bc5c..d237c190 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -1,11 
+1,14 @@ import _pickle import inspect import os +import warnings from collections import Counter from collections import namedtuple import torch -CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) + +CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', + 'varargs'], verbose=False) def save_pickle(obj, pickle_path, file_name): """Save an object into a pickle file. @@ -105,7 +108,6 @@ def _check_arg_dict_list(func, args): assert callable(func) and isinstance(arg_dict_list, (list, tuple)) assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) spect = inspect.getfullargspec(func) - assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs) all_args = set([arg for arg in spect.args if arg!='self']) defaults = [] if spect.defaults is not None: @@ -125,7 +127,8 @@ def _check_arg_dict_list(func, args): unused=unused, duplicated=duplicated, required=list(require_args), - all_needed=list(all_args)) + all_needed=list(all_args), + varargs=[spect.varargs] if spect.varargs else []) def get_func_signature(func): """ @@ -221,15 +224,73 @@ class CheckError(Exception): CheckError. Used in losses.LossBase, metrics.MetricBase. 
""" def __init__(self, check_res:CheckRes, func_signature:str): - err = '' + errs = [f'The following problems occurred when calling {func_signature}'] + + if check_res.varargs: + errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") if check_res.missing: - err += f"Missing: {check_res.missing}\n" + errs.append(f"\tmissing param: {check_res.missing}") if check_res.duplicated: - err += f"Duplicated: {check_res.duplicated}\n" + errs.append(f"\tduplicated param: {check_res.duplicated}") if check_res.unused: - err += f"Unused: {check_res.unused}\n" + errs.append(f"\tunused param: {check_res.unused}") - Exception.__init__(self, err) + Exception.__init__(self, '\n'.join(errs)) self.check_res = check_res self.func_signature = func_signature + +IGNORE_CHECK_LEVEL = 0 +WARNING_CHECK_LEVEL = 1 +STRICT_CHECK_LEVEL = 2 + +def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes, + output:dict, batch_y:dict, check_level=0): + errs = [] + _unused = [] + if check_res.varargs: + errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, " + f"please delete it.)") + if check_res.missing: + errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())}" + f"(from {prev_func_signature}) and {list(batch_y.keys())}(from targets in Dataset).") + if check_res.duplicated: + errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of " + f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") + if check_res.unused: + _unused = [f"\tunused param: {check_res.unused}"] + if check_level == STRICT_CHECK_LEVEL: + errs.extend(_unused) + + if len(errs)>0: + errs.insert(0, f'The following problems occurred when calling {func_signature}') + raise NameError('\n'.join(errs)) + if _unused: + if check_level == WARNING_CHECK_LEVEL: + _unused_warn = _unused[0] + f' in {func_signature}.' 
+ warnings.warn(message=_unused_warn) + + +def _check_forward_error(forward_func, batch_x, check_level): + check_res = _check_arg_dict_list(forward_func, batch_x) + func_signature = get_func_signature(forward_func) + + errs = [] + _unused = [] + + if check_res.varargs: + errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") + if check_res.missing: + errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.") + if check_res.unused: + _unused = [f"\tunused param: {check_res.unused}"] + if check_level == STRICT_CHECK_LEVEL: + errs.extend(_unused) + + if len(errs)>0: + errs.insert(0, f'The following problems occurred when calling {func_signature}') + raise NameError('\n'.join(errs)) + if _unused: + if check_level == WARNING_CHECK_LEVEL: + _unused_warn = _unused[0] + f' in {func_signature}.' + warnings.warn(message=_unused_warn) \ No newline at end of file