Browse Source

add args check & build function

tags/v0.2.0
yunfan 6 years ago
parent
commit
cbf54c1918
1 changed files with 30 additions and 17 deletions
  1. +30
    -17
      fastNLP/core/utils.py

+ 30
- 17
fastNLP/core/utils.py View File

@@ -45,25 +45,38 @@ def pickle_exist(pickle_path, pickle_name):
else:
return False

def build_args(func, kwargs):
assert isinstance(func, function) and isinstance(kwargs, dict)
def build_args(func, **kwargs):
spect = inspect.getfullargspec(func)
assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
if spect.varkw is not None:
return kwargs
needed_args = set(spect.args)
output = {name: default for name, default in zip(reversed(spect.args), reversed(spect.defaults))}
start_idx = len(spect.args) - len(spect.defaults)
output = {name: default for name, default in zip(spect.args[start_idx:], spect.defaults)}
output.update({name: val for name, val in kwargs.items() if name in needed_args})
if spect.varkw is not None:
output.update(kwargs)


# check miss args
def check_arg_dict(func, arg_dict):
pass

def check_arg_dict_list(func, arg_dict_list):
pass

def check_code():
pass
return output

from collections import namedtuple, Counter
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)

# check args
def check_arg_dict_list(func, args):
if isinstance(args, dict):
arg_dict_list = [args]
else:
arg_dict_list = args
assert callable(func) and isinstance(arg_dict_list, (list, tuple))
assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
spect = inspect.getfullargspec(func)
assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
all_args = set(spect.args)
start_idx = len(spect.args) - len(spect.defaults)
default_args = set(spect.args[start_idx:])
require_args = all_args - default_args
input_arg_count = Counter()
for arg_dict in arg_dict_list:
input_arg_count.update(arg_dict.keys())
duplicated = [name for name, val in input_arg_count.items() if val > 1]
input_args = set(input_arg_count.keys())
missing = list(require_args - input_args)
unused = list(input_args - all_args)
return CheckRes(missing=missing, unused=unused, duplicated=duplicated)

Loading…
Cancel
Save