diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 39af672c..550ef7d9 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -200,6 +200,12 @@ class DataSet(object): raise KeyError("{} is not a valid field name.".format(name)) return self + def get_input_name(self): + return [name for name, field in self.field_arrays.items() if field.is_input] + + def get_target_name(self): + return [name for name, field in self.field_arrays.items() if field.is_target] + def __getattr__(self, item): if item in self.field_arrays: return self.field_arrays[item] diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 6a284ab9..ca38e45e 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -4,7 +4,7 @@ import inspect from collections import namedtuple from collections import Counter -CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True) +CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=True) def save_pickle(obj, pickle_path, file_name): @@ -82,4 +82,8 @@ def _check_arg_dict_list(func, args): input_args = set(input_arg_count.keys()) missing = list(require_args - input_args) unused = list(input_args - all_args) - return CheckRes(missing=missing, unused=unused, duplicated=duplicated) + return CheckRes(missing=missing, + unused=unused, + duplicated=duplicated, + required=list(require_args), + all_needed=list(all_args))