def print_field_meta(self):
    """
    Log a table describing the meta information of every field in this DataSet.

    The logged table looks like::

        +-------------+-------+-------+
        | field_names |   x   |   y   |
        +-------------+-------+-------+
        |   is_input  |  True | False |
        |  is_target  | False | False |
        | ignore_type | False |       |
        |  pad_value  |   0   |       |
        +-------------+-------+-------+

    Rows:

    - field_names: name of each field in the DataSet
    - is_input: whether the field is marked as input
    - is_target: whether the field is marked as target
    - ignore_type: whether the type of the field is ignored; left blank
      unless the flag is set or the field is an input or target
    - pad_value: the pad value of the field; left blank unless the field
      is an input or target and has a padder

    :return: the ``PrettyTable`` that was logged, or ``None`` when the
        DataSet has no fields (nothing is logged in that case).
    """
    if not self.field_arrays:
        # Nothing to describe: log nothing.
        return None

    # The first cell of every row is the row label; one more cell is
    # appended per field so all rows stay aligned with the header.
    field_names = ['field_names']
    is_inputs = ['is_input']
    is_targets = ['is_target']
    ignore_types = ['ignore_type']
    pad_values = ['pad_value']

    for name, field_array in self.field_arrays.items():
        # Padding and type checking only matter for fields that are used
        # as input or target.
        in_use = field_array.is_input or field_array.is_target

        field_names.append(name)
        is_inputs.append(bool(field_array.is_input))
        is_targets.append(bool(field_array.is_target))

        if in_use and field_array.padder is not None:
            pad_values.append(field_array.padder.get_pad_val())
        else:
            pad_values.append(' ')

        if field_array._ignore_type:
            ignore_types.append(True)
        elif in_use:
            ignore_types.append(False)
        else:
            ignore_types.append(' ')

    table = PrettyTable(field_names=field_names)
    # Row order matches the example in the docstring.
    for row in (is_inputs, is_targets, ignore_types, pad_values):
        table.add_row(row)
    logger.info(table)
    # Returned so callers and tests can inspect what was logged.
    return table
class TestDataSetFieldMeta(unittest.TestCase):
    """Smoke test for ``DataSet.print_field_meta``."""

    def test_print_field_meta(self):
        """Call ``print_field_meta`` before and after marking ``x`` as input."""
        columns = {"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}
        dataset = DataSet(columns)

        # First call: neither field has been marked as input or target.
        dataset.print_field_meta()

        # Second call: ``x`` is now an input field.
        dataset.set_input('x')
        dataset.print_field_meta()