Browse Source

DataSet中增加print_field_meta方法,使得其可以获取field的input和target信息

tags/v0.4.10
yh_cc 5 years ago
parent
commit
8c8e22cc9b
4 changed files with 77 additions and 3 deletions
  1. +57
    -0
      fastNLP/core/dataset.py
  2. +5
    -2
      fastNLP/core/field.py
  3. +1
    -0
      requirements.txt
  4. +14
    -1
      test/core/test_dataset.py

+ 57
- 0
fastNLP/core/dataset.py View File

@@ -301,6 +301,7 @@ from .field import SetInputOrTargetException
from .instance import Instance from .instance import Instance
from .utils import _get_func_signature from .utils import _get_func_signature
from .utils import pretty_table_printer from .utils import pretty_table_printer
from prettytable import PrettyTable




class DataSet(object): class DataSet(object):
@@ -425,6 +426,62 @@ class DataSet(object):
def __repr__(self): def __repr__(self):
return str(pretty_table_printer(self)) return str(pretty_table_printer(self))


def print_field_meta(self):
    """
    Log (and return) a table summarising the meta information of every field.

    The output looks like the following::

        +-------------+-------+-------+
        | field_names |   x   |   y   |
        +-------------+-------+-------+
        |   is_input  |  True | False |
        |  is_target  | False | False |
        | ignore_type | False |       |
        |  pad_value  |   0   |       |
        +-------------+-------+-------+

    Rows of the table:

    - field_names: name of each field in the DataSet
    - is_input: whether the field is an input
    - is_target: whether the field is a target
    - ignore_type: whether the type of the field is ignored; left blank when
      the field is neither an input nor a target and the flag is not set
    - pad_value: value returned by the field's padder; left blank when the
      field is neither an input nor a target, or when it has no padder

    :return: the ``PrettyTable`` that was logged, or ``None`` when the
        DataSet has no fields (nothing is logged in that case)
    """
    if len(self.field_arrays) == 0:
        return None

    # The first cell of every row is the row label; the header row doubles
    # as the list of column names handed to PrettyTable.
    field_names = ['field_names']
    is_inputs = ['is_input']
    is_targets = ['is_target']
    ignore_types = ['ignore_type']
    pad_values = ['pad_value']

    for name, field_array in self.field_arrays.items():
        in_use = field_array.is_input or field_array.is_target

        field_names.append(name)
        is_inputs.append(bool(field_array.is_input))
        is_targets.append(bool(field_array.is_target))

        if field_array._ignore_type:
            ignore_types.append(True)
        elif in_use:
            ignore_types.append(False)
        else:
            # Flag is irrelevant for a field that is neither input nor target.
            ignore_types.append(' ')

        if in_use and field_array.padder is not None:
            pad_values.append(field_array.padder.get_pad_val())
        else:
            pad_values.append(' ')

    table = PrettyTable(field_names=field_names)
    for row in (is_inputs, is_targets, ignore_types, pad_values):
        table.add_row(row)
    logger.info(table)
    return table

def append(self, instance): def append(self, instance):
""" """
将一个instance对象append到DataSet后面。 将一个instance对象append到DataSet后面。


+ 5
- 2
fastNLP/core/field.py View File

@@ -53,7 +53,7 @@ class FieldArray:
self.content = _content self.content = _content
self._ignore_type = ignore_type self._ignore_type = ignore_type
# 根据input的情况设置input,target等 # 根据input的情况设置input,target等
self._cell_ndim = None # 多少维度
self._cell_ndim = None # 多少维度, 如果value是1, dim为0; 如果value是[1, 2], dim=2
self.dtype = None # 最内层的element都是什么类型的 self.dtype = None # 最内层的element都是什么类型的
self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self._is_input = False self._is_input = False
@@ -484,7 +484,10 @@ class Padder:
def set_pad_val(self, pad_val): def set_pad_val(self, pad_val):
self.pad_val = pad_val self.pad_val = pad_val

def get_pad_val(self):
    """Return the padding value currently stored in ``self.pad_val``."""
    return self.pad_val

@abstractmethod @abstractmethod
def __call__(self, contents, field_name, field_ele_dtype, dim: int): def __call__(self, contents, field_name, field_ele_dtype, dim: int):
""" """


+ 1
- 0
requirements.txt View File

@@ -4,3 +4,4 @@ tqdm>=4.28.1
nltk>=3.4.1 nltk>=3.4.1
requests requests
spacy spacy
prettytable>=0.7.2

+ 14
- 1
test/core/test_dataset.py View File

@@ -229,4 +229,17 @@ class TestDataSetIter(unittest.TestCase):
def test__repr__(self): def test__repr__(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
for iter in ds: for iter in ds:
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4] type=list,\n'y': [5, 6] type=list}")
self.assertEqual(iter.__repr__(), """+--------------+--------+
| x | y |
+--------------+--------+
| [1, 2, 3, 4] | [5, 6] |
+--------------+--------+""")


class TestDataSetFieldMeta(unittest.TestCase):
    """Tests for ``DataSet.print_field_meta``."""

    def test_print_field_meta(self):
        """Printing the meta table works before and after a field is made an input."""
        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
        # Freshly built fields are not inputs yet.
        self.assertFalse(ds.field_arrays['x'].is_input)
        ds.print_field_meta()

        ds.set_input('x')
        # Only the field passed to set_input should be flagged as an input.
        self.assertTrue(ds.field_arrays['x'].is_input)
        self.assertFalse(ds.field_arrays['y'].is_input)
        ds.print_field_meta()

Loading…
Cancel
Save