From 3e01b142490c696642226c77705f4994a7418019 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Mon, 25 Feb 2019 10:26:03 +0800 Subject: [PATCH] add ignore_type in DataSet.add_field --- fastNLP/core/dataset.py | 5 ++- fastNLP/core/fieldarray.py | 73 +++++++++++++++++++----------------- test/core/test_dataset.py | 7 +++- test/core/test_fieldarray.py | 7 ++++ 4 files changed, 55 insertions(+), 37 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 601fa589..f25e2cfd 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -157,7 +157,7 @@ class DataSet(object): assert name in self.field_arrays self.field_arrays[name].append(field) - def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False): + def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False): """Add a new field to the DataSet. :param str name: the name of the field. @@ -165,13 +165,14 @@ class DataSet(object): :param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 :param bool is_input: whether this field is model input. :param bool is_target: whether this field is label or target. + :param bool ignore_type: If True, do not perform type check. (Default: False) """ if len(self.field_arrays) != 0: if len(self) != len(fields): raise RuntimeError(f"The field to append must have the same size as dataset. " f"Dataset size {len(self)} != field size {len(fields)}") self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, - padder=padder) + padder=padder, ignore_type=ignore_type) def delete_field(self, name): """Delete a field based on the field name. diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index f3fcb3c8..148dfc6c 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -96,10 +96,11 @@ class FieldArray(object): :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. :param bool is_target: If True, this FieldArray is used to compute loss. :param bool is_input: If True, this FieldArray is used to the model input. - :param padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding) + :param PadderBase padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding) + :param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. (default: False) """ - def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)): + def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0), ignore_type=False): """DataSet在初始化时会有两类方法对FieldArray操作: 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) @@ -114,6 +115,7 @@ class FieldArray(object): 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) 类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 + ignore_type用来控制是否进行类型检查,如果为True,则不检查。 """ self.name = name @@ -136,6 +138,7 @@ class FieldArray(object): self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 self.content_dim = None # 表示content是多少维的list self.set_padder(padder) + self.ignore_type = ignore_type self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array @@ -149,8 +152,9 @@ class FieldArray(object): self.is_target = is_target def _set_dtype(self): - self.pytype = self._type_detection(self.content) - self.dtype = self._map_to_np_type(self.pytype) + if self.ignore_type is False: + self.pytype = self._type_detection(self.content) + self.dtype = self._map_to_np_type(self.pytype) @property def is_input(self): @@ -278,39 +282,40 @@ class FieldArray(object): :param val: int, float, str, or a list of one. """ - if isinstance(val, list): - pass - elif isinstance(val, tuple): # 确保最外层是list - val = list(val) - elif isinstance(val, np.ndarray): - val = val.tolist() - elif any((isinstance(val, t) for t in self.BASIC_TYPES)): - pass - else: - raise RuntimeError( - "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - - if self.is_input is True or self.is_target is True: - if type(val) == list: - if len(val) == 0: - raise ValueError("Cannot append an empty list.") - if self.content_dim == 2 and self._1d_list_check(val): - # 1维list检查 - pass - elif self.content_dim == 3 and self._2d_list_check(val): - # 2维list检查 - pass - else: - raise RuntimeError( - "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) - elif type(val) in self.BASIC_TYPES and self.content_dim == 1: - # scalar检查 - if type(val) == float and self.pytype == int: - self.pytype = float - self.dtype = self._map_to_np_type(self.pytype) + if self.ignore_type is False: + if isinstance(val, list): + pass + elif isinstance(val, tuple): # 确保最外层是list + val = list(val) + elif isinstance(val, np.ndarray): + val = val.tolist() + elif any((isinstance(val, t) for t in self.BASIC_TYPES)): + pass else: raise RuntimeError( "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) + + if self.is_input is True or self.is_target is True: + if type(val) == list: + if len(val) == 0: + raise ValueError("Cannot append an empty list.") + if self.content_dim == 2 and self._1d_list_check(val): + # 1维list检查 + pass + elif self.content_dim == 3 and self._2d_list_check(val): + # 2维list检查 + pass + else: + raise RuntimeError( + "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) + elif type(val) in self.BASIC_TYPES and self.content_dim == 1: + # scalar检查 + if type(val) == float and self.pytype == int: + self.pytype = float + self.dtype = self._map_to_np_type(self.pytype) + else: + raise RuntimeError( + "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) self.content.append(val) def __getitem__(self, indices): diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 72ced912..231fedd0 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -52,7 +52,7 @@ class TestDataSetMethods(unittest.TestCase): self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) - def test_add_append(self): + def test_add_field(self): dd = DataSet() dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10) @@ -65,6 +65,11 @@ class TestDataSetMethods(unittest.TestCase): with self.assertRaises(RuntimeError): dd.add_field("??", [[1, 2]] * 40) + def test_add_field_ignore_type(self): + dd = DataSet() + dd.add_field("x", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], ignore_type=True, is_target=True) + dd.add_field("y", [{1, "1"}, {2, "2"}, {3, "3"}, {4, "4"}], ignore_type=True, is_target=True) + def test_delete_field(self): dd = DataSet() dd.add_field("x", [[1, 2, 3]] * 10) diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index 151d9335..e3595f9a 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -155,6 +155,13 @@ class TestFieldArray(unittest.TestCase): self.assertEqual(len(fa), 3) self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) + def test_ignore_type(self): + # 测试新添加的参数ignore_type,用来跳过类型检查 + fa = FieldArray("y", [[1.1, 2.2, "jin", {}, "hahah"], [int, 2, "$", 4, 5]], is_input=True, ignore_type=True) + fa.append([1.2, 2.3, str, 4.5, print]) + + fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True) + class TestPadder(unittest.TestCase):