Browse Source

add ignore_type in DataSet.add_field

tags/v0.4.10
FengZiYjun 5 years ago
parent
commit
3e01b14249
4 changed files with 55 additions and 37 deletions
  1. +3
    -2
      fastNLP/core/dataset.py
  2. +39
    -34
      fastNLP/core/fieldarray.py
  3. +6
    -1
      test/core/test_dataset.py
  4. +7
    -0
      test/core/test_fieldarray.py

+ 3
- 2
fastNLP/core/dataset.py View File

@@ -157,7 +157,7 @@ class DataSet(object):
assert name in self.field_arrays assert name in self.field_arrays
self.field_arrays[name].append(field) self.field_arrays[name].append(field)


def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False):
def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False):
"""Add a new field to the DataSet. """Add a new field to the DataSet.
:param str name: the name of the field. :param str name: the name of the field.
@@ -165,13 +165,14 @@ class DataSet(object):
:param PadderBase padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 :param PadderBase padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可
:param bool is_input: whether this field is model input. :param bool is_input: whether this field is model input.
:param bool is_target: whether this field is label or target. :param bool is_target: whether this field is label or target.
:param bool ignore_type: If True, do not perform type check. (Default: False)
""" """
if len(self.field_arrays) != 0: if len(self.field_arrays) != 0:
if len(self) != len(fields): if len(self) != len(fields):
raise RuntimeError(f"The field to append must have the same size as dataset. " raise RuntimeError(f"The field to append must have the same size as dataset. "
f"Dataset size {len(self)} != field size {len(fields)}") f"Dataset size {len(self)} != field size {len(fields)}")
self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input,
padder=padder)
padder=padder, ignore_type=ignore_type)


def delete_field(self, name): def delete_field(self, name):
"""Delete a field based on the field name. """Delete a field based on the field name.


+ 39
- 34
fastNLP/core/fieldarray.py View File

@@ -96,10 +96,11 @@ class FieldArray(object):
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
:param bool is_target: If True, this FieldArray is used to compute loss. :param bool is_target: If True, this FieldArray is used to compute loss.
:param bool is_input: If True, this FieldArray is used as the model input. :param bool is_input: If True, this FieldArray is used as the model input.
:param padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding)
:param PadderBase padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding)
:param bool ignore_type: whether to skip type checking. If True, no type detection is performed for this FieldArray. (default: False)
""" """


def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)):
def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0), ignore_type=False):
"""DataSet在初始化时会有两类方法对FieldArray操作: """DataSet在初始化时会有两类方法对FieldArray操作:
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
@@ -114,6 +115,7 @@ class FieldArray(object):
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])


类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。
ignore_type用来控制是否进行类型检查,如果为True,则不检查。


""" """
self.name = name self.name = name
@@ -136,6 +138,7 @@ class FieldArray(object):
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐
self.content_dim = None # 表示content是多少维的list self.content_dim = None # 表示content是多少维的list
self.set_padder(padder) self.set_padder(padder)
self.ignore_type = ignore_type


self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array


@@ -149,8 +152,9 @@ class FieldArray(object):
self.is_target = is_target self.is_target = is_target


def _set_dtype(self): def _set_dtype(self):
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)
if self.ignore_type is False:
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)


@property @property
def is_input(self): def is_input(self):
@@ -278,39 +282,40 @@ class FieldArray(object):


:param val: int, float, str, or a list of one. :param val: int, float, str, or a list of one.
""" """
if isinstance(val, list):
pass
elif isinstance(val, tuple): # 确保最外层是list
val = list(val)
elif isinstance(val, np.ndarray):
val = val.tolist()
elif any((isinstance(val, t) for t in self.BASIC_TYPES)):
pass
else:
raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))

if self.is_input is True or self.is_target is True:
if type(val) == list:
if len(val) == 0:
raise ValueError("Cannot append an empty list.")
if self.content_dim == 2 and self._1d_list_check(val):
# 1维list检查
pass
elif self.content_dim == 3 and self._2d_list_check(val):
# 2维list检查
pass
else:
raise RuntimeError(
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val))
elif type(val) in self.BASIC_TYPES and self.content_dim == 1:
# scalar检查
if type(val) == float and self.pytype == int:
self.pytype = float
self.dtype = self._map_to_np_type(self.pytype)
if self.ignore_type is False:
if isinstance(val, list):
pass
elif isinstance(val, tuple): # 确保最外层是list
val = list(val)
elif isinstance(val, np.ndarray):
val = val.tolist()
elif any((isinstance(val, t) for t in self.BASIC_TYPES)):
pass
else: else:
raise RuntimeError( raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))

if self.is_input is True or self.is_target is True:
if type(val) == list:
if len(val) == 0:
raise ValueError("Cannot append an empty list.")
if self.content_dim == 2 and self._1d_list_check(val):
# 1维list检查
pass
elif self.content_dim == 3 and self._2d_list_check(val):
# 2维list检查
pass
else:
raise RuntimeError(
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val))
elif type(val) in self.BASIC_TYPES and self.content_dim == 1:
# scalar检查
if type(val) == float and self.pytype == int:
self.pytype = float
self.dtype = self._map_to_np_type(self.pytype)
else:
raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))
self.content.append(val) self.content.append(val)


def __getitem__(self, indices): def __getitem__(self, indices):


+ 6
- 1
test/core/test_dataset.py View File

@@ -52,7 +52,7 @@ class TestDataSetMethods(unittest.TestCase):
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)


def test_add_append(self):
def test_add_field(self):
dd = DataSet() dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10)
@@ -65,6 +65,11 @@ class TestDataSetMethods(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
dd.add_field("??", [[1, 2]] * 40) dd.add_field("??", [[1, 2]] * 40)


def test_add_field_ignore_type(self):
dd = DataSet()
dd.add_field("x", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], ignore_type=True, is_target=True)
dd.add_field("y", [{1, "1"}, {2, "2"}, {3, "3"}, {4, "4"}], ignore_type=True, is_target=True)

def test_delete_field(self): def test_delete_field(self):
dd = DataSet() dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("x", [[1, 2, 3]] * 10)


+ 7
- 0
test/core/test_fieldarray.py View File

@@ -155,6 +155,13 @@ class TestFieldArray(unittest.TestCase):
self.assertEqual(len(fa), 3) self.assertEqual(len(fa), 3)
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])


def test_ignore_type(self):
# 测试新添加的参数ignore_type,用来跳过类型检查
fa = FieldArray("y", [[1.1, 2.2, "jin", {}, "hahah"], [int, 2, "$", 4, 5]], is_input=True, ignore_type=True)
fa.append([1.2, 2.3, str, 4.5, print])

fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True)



class TestPadder(unittest.TestCase): class TestPadder(unittest.TestCase):




Loading…
Cancel
Save