diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 82e3d07c..efc2ef7e 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -3,9 +3,9 @@ import inspect import os from collections import Counter from collections import namedtuple -from collections import defaultdict import torch +CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) def save_pickle(obj, pickle_path, file_name): """Save an object into a pickle file. @@ -121,14 +121,11 @@ def _check_arg_dict_list(func, args): missing = list(require_args - input_args) unused = list(input_args - all_args) - check_res = {} - check_res['missing'] = missing - check_res['unused'] = unused - check_res['duplicated'] = duplicated - check_res['required'] = list(require_args) - check_res['all_needed'] = list(all_args) - - return check_res + return CheckRes(missing=missing, + unused=unused, + duplicated=duplicated, + required=list(require_args), + all_needed=list(all_args)) def get_func_signature(func): """ @@ -165,6 +162,19 @@ def get_func_signature(func): signature_str = func.__name__ + signature_str return signature_str +def _is_function_or_method(func): + """ + + :param func: + :return: + """ + if not inspect.ismethod(func) and not inspect.isfunction(func): + return False + return True + +def _check_function_or_method(func): + if not _is_function_or_method(func): + raise TypeError(f"{type(func)} is not a method or function.") def _syn_model_data(model, *args): """ @@ -204,3 +214,19 @@ def _move_dict_value_to_device(device, *args): else: raise TypeError("Only support `dict` type right now.") + +class CheckError(Exception): + """ + + CheckError. Used in losses.LossBase, metrics.MetricBase. + """ + def __init__(self, check_res): + err = '' + if check_res['missing']: + err += f"Missing: {check_res['missing']}\n" + if check_res['duplicated']: + err += f"Duplicated: {check_res['duplicated']}\n" + if check_res['unused']: + err += f"Unused: {check_res['unused']}\n" + Exception.__init__(self, err) + self.check_res = check_res