diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index c9a89f90..b672be77 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -45,25 +45,38 @@ def pickle_exist(pickle_path, pickle_name): else: return False -def build_args(func, kwargs): - assert isinstance(func, function) and isinstance(kwargs, dict) +def build_args(func, **kwargs): spect = inspect.getfullargspec(func) - assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs) + if spect.varkw is not None: + return kwargs needed_args = set(spect.args) - output = {name: default for name, default in zip(reversed(spect.args), reversed(spect.defaults))} + start_idx = len(spect.args) - len(spect.defaults) + output = {name: default for name, default in zip(spect.args[start_idx:], spect.defaults)} output.update({name: val for name, val in kwargs.items() if name in needed_args}) - if spect.varkw is not None: - output.update(kwargs) - - -# check miss args -def check_arg_dict(func, arg_dict): - pass - -def check_arg_dict_list(func, arg_dict_list): - pass - -def check_code(): - pass + return output +from collections import namedtuple, Counter +CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True) +# check args +def check_arg_dict_list(func, args): + if isinstance(args, dict): + arg_dict_list = [args] + else: + arg_dict_list = args + assert callable(func) and isinstance(arg_dict_list, (list, tuple)) + assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) + spect = inspect.getfullargspec(func) + assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs) + all_args = set(spect.args) + start_idx = len(spect.args) - len(spect.defaults) + default_args = set(spect.args[start_idx:]) + require_args = all_args - default_args + input_arg_count = Counter() + for arg_dict in arg_dict_list: + input_arg_count.update(arg_dict.keys()) + duplicated = [name for name, val in input_arg_count.items() if val > 1] + input_args = set(input_arg_count.keys()) + missing = list(require_args - input_args) + unused = list(input_args - all_args) + return CheckRes(missing=missing, unused=unused, duplicated=duplicated)