Browse Source

update check_args and add Dataset get_input/target_name

tags/v0.2.0
yunfan 6 years ago
parent
commit
c7923c82e7
2 changed files with 12 additions and 2 deletions
  1. +6
    -0
      fastNLP/core/dataset.py
  2. +6
    -2
      fastNLP/core/utils.py

+ 6
- 0
fastNLP/core/dataset.py View File

@@ -200,6 +200,12 @@ class DataSet(object):
raise KeyError("{} is not a valid field name.".format(name))
return self

def get_input_name(self):
return [name for name, field in self.field_arrays.items() if field.is_input]

def get_target_name(self):
return [name for name, field in self.field_arrays.items() if field.is_target]

def __getattr__(self, item):
if item in self.field_arrays:
return self.field_arrays[item]


+ 6
- 2
fastNLP/core/utils.py View File

@@ -4,7 +4,7 @@ import inspect
from collections import namedtuple
from collections import Counter

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=True)


def save_pickle(obj, pickle_path, file_name):
@@ -82,4 +82,8 @@ def _check_arg_dict_list(func, args):
input_args = set(input_arg_count.keys())
missing = list(require_args - input_args)
unused = list(input_args - all_args)
return CheckRes(missing=missing, unused=unused, duplicated=duplicated)
return CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args))

Loading…
Cancel
Save