From 27833d06ae7ab67480e1b43df05ffbc092d86244 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Tue, 4 Dec 2018 16:13:20 +0800 Subject: [PATCH] FieldArray only check type when is_input or is_target is set. --- fastNLP/core/fieldarray.py | 110 +++++++++++++++++++++++------------ test/core/test_fieldarray.py | 23 ++++++++ test/core/test_metrics.py | 31 +++++----- 3 files changed, 111 insertions(+), 53 deletions(-) diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 0a94b26c..2340cd13 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -7,11 +7,11 @@ class FieldArray(object): """ - def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): + def __init__(self, name, content, padding_val=0, is_target=None, is_input=None): """ :param str name: the name of the FieldArray - :param list content: a list of int, float, str or np.ndarray, or a list of list of one. + :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. :param int padding_val: the integer for padding. Default: 0. :param bool is_target: If True, this FieldArray is used to compute loss. :param bool is_input: If True, this FieldArray is used to the model input. @@ -20,18 +20,44 @@ class FieldArray(object): if isinstance(content, list): content = content elif isinstance(content, np.ndarray): - content = content.tolist() + content = content.tolist() # convert np.ndarray into 2-D list else: raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) self.content = content self.padding_val = padding_val - self.is_target = is_target - self.is_input = is_input + + self._is_target = None + self._is_input = None self.BASIC_TYPES = (int, float, str, np.ndarray) self.is_2d_list = False - self.pytype = self._type_detection(content) + self.pytype = None # int, float, str, or np.ndarray + self.dtype = None # np.int64, np.float64, np.str + + if is_input is not None: + self.is_input = is_input + if is_target is not None: + self.is_target = is_target + + @property + def is_input(self): + return self._is_input + + @is_input.setter + def is_input(self, value): + self.pytype = self._type_detection(self.content) + self.dtype = self._map_to_np_type(self.pytype) + self._is_input = value + + @property + def is_target(self): + return self._is_target + + @is_target.setter + def is_target(self, value): + self.pytype = self._type_detection(self.content) self.dtype = self._map_to_np_type(self.pytype) + self._is_target = value def _type_detection(self, content): """ @@ -42,9 +68,13 @@ class FieldArray(object): """ if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): # content is a 2-D list + if not all(isinstance(_, list) for _ in content): # strict check 2-D list + raise TypeError("Please provide 2-D list.") type_set = set([self._type_detection(x) for x in content]) - if len(type_set) > 1: - raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) + if len(type_set) == 2 and int in type_set and float in type_set: + type_set = {float} + elif len(type_set) > 1: + raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) self.is_2d_list = True return type_set.pop() @@ -60,9 +90,9 @@ class FieldArray(object): # up-cast int to float return float else: - raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set)) + raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) else: - raise RuntimeError("Cannot create FieldArray with type {}".format(type(content))) + raise TypeError("Cannot create FieldArray with type {}".format(type(content))) @staticmethod def _map_to_np_type(basic_type): @@ -77,33 +107,38 @@ class FieldArray(object): :param val: int, float, str, or a list of one. """ - val_type = type(val) - if val_type == list: # shape check - if self.is_2d_list is False: - raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") - if len(val) == 0: - raise RuntimeError("Cannot append an empty list.") - val_list_type = set([type(_) for _ in val]) # type check - if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: - # up-cast int to float - val_type = float - elif len(val_list_type) == 1: - val_type = val_list_type.pop() + if self.is_target is True or self.is_input is True: + # only check type when used as target or input + + val_type = type(val) + if val_type == list: # shape check + if self.is_2d_list is False: + raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") + if len(val) == 0: + raise RuntimeError("Cannot append an empty list.") + val_list_type = set([type(_) for _ in val]) # type check + if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: + # up-cast int to float + val_type = float + elif len(val_list_type) == 1: + val_type = val_list_type.pop() + else: + raise TypeError("Cannot append a list of {}".format(val_list_type)) else: - raise RuntimeError("Cannot append a list of {}".format(val_list_type)) - else: - if self.is_2d_list is True: - raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") - if val_type == float and self.pytype == int: - # up-cast - self.pytype = float - self.dtype = self._map_to_np_type(self.pytype) - elif val_type == int and self.pytype == float: - pass - elif val_type == self.pytype: - pass - else: - raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype)) + if self.is_2d_list is True: + raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") + + if val_type == float and self.pytype == int: + # up-cast + self.pytype = float + self.dtype = self._map_to_np_type(self.pytype) + elif val_type == int and self.pytype == float: + pass + elif val_type == self.pytype: + pass + else: + raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) + self.content.append(val) def __getitem__(self, indices): @@ -121,7 +156,8 @@ class FieldArray(object): """ if isinstance(indices, int): return self.content[indices] - assert self.is_input is True or self.is_target is True + if self.is_input is False and self.is_target is False: + raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) batch_size = len(indices) # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 if not is_iterable(self.content[0]): diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index 0264c2ff..c22bac5b 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -44,11 +44,34 @@ class TestFieldArray(unittest.TestCase): def test_support_np_array(self): fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False) self.assertEqual(fa.dtype, np.ndarray) + self.assertEqual(fa.pytype, np.ndarray) fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) + self.assertEqual(fa.dtype, np.ndarray) self.assertEqual(fa.pytype, np.ndarray) + fa = FieldArray("my_field", np.random.rand(3, 5), is_input=False) + # in this case, pytype is actually a float. We do not care about it. + self.assertEqual(fa.dtype, np.float64) + def test_nested_list(self): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False) self.assertEqual(fa.pytype, float) self.assertEqual(fa.dtype, np.float64) + + def test_getitem_v1(self): + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) + ans = fa[[0, 1]] + self.assertTrue(isinstance(ans, np.ndarray)) + self.assertTrue(isinstance(ans[0], np.ndarray)) + self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) + self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) + self.assertEqual(ans.dtype, np.float64) + + def test_getitem_v2(self): + x = np.random.rand(10, 5) + fa = FieldArray("my_field", x, is_input=True) + indices = [0, 1, 3, 4, 6] + for a, b in zip(fa[indices], x[indices]): + self.assertListEqual(a.tolist(), b.tolist()) diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 1b8ae70b..76352aba 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -1,9 +1,10 @@ - import unittest -from fastNLP.core.metrics import AccuracyMetric -import torch import numpy as np +import torch + +from fastNLP.core.metrics import AccuracyMetric + class TestAccuracyMetric(unittest.TestCase): def test_AccuracyMetric1(self): @@ -12,9 +13,9 @@ class TestAccuracyMetric(unittest.TestCase): target_dict = {'target': torch.zeros(4)} metric = AccuracyMetric() - metric(pred_dict=pred_dict, target_dict=target_dict, ) + metric(pred_dict=pred_dict, target_dict=target_dict, ) print(metric.get_metric()) - # + def test_AccuracyMetric2(self): # (2) with corrupted size try: @@ -22,13 +23,13 @@ class TestAccuracyMetric(unittest.TestCase): target_dict = {'target': torch.zeros(4)} metric = AccuracyMetric() - metric(pred_dict=pred_dict, target_dict=target_dict, ) + metric(pred_dict=pred_dict, target_dict=target_dict, ) print(metric.get_metric()) except Exception as e: print(e) return self.assertTrue(True, False), "No exception catches." - # + def test_AccuracyMetric3(self): # (3) the second batch is corrupted size try: @@ -47,7 +48,6 @@ class TestAccuracyMetric(unittest.TestCase): return self.assertTrue(True, False), "No exception catches." - # def test_AccuaryMetric4(self): # (5) check reset metric = AccuracyMetric() @@ -57,9 +57,9 @@ class TestAccuracyMetric(unittest.TestCase): self.assertDictEqual(metric.get_metric(), {'acc': 1}) pred_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3)+1} + target_dict = {'target': torch.zeros(4, 3) + 1} metric(pred_dict=pred_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc':0}) + self.assertDictEqual(metric.get_metric(), {'acc': 0}) def test_AccuaryMetric5(self): # (5) check reset @@ -70,11 +70,10 @@ class TestAccuracyMetric(unittest.TestCase): self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) pred_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3)+1} + target_dict = {'target': torch.zeros(4, 3) + 1} metric(pred_dict=pred_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc':0.5}) + self.assertDictEqual(metric.get_metric(), {'acc': 0.5}) - # def test_AccuaryMetric6(self): # (6) check numpy array is not acceptable try: @@ -99,9 +98,9 @@ class TestAccuracyMetric(unittest.TestCase): # (8) check map, does not match. use stop_fast_param to stop fast param map try: metric = AccuracyMetric(pred='predictions', target='targets') - pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param":1} + pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1} target_dict = {'targets': torch.zeros(4, 3)} - metric(pred_dict=pred_dict, target_dict=target_dict, ) + metric(pred_dict=pred_dict, target_dict=target_dict, ) self.assertDictEqual(metric.get_metric(), {'acc': 1}) except Exception as e: print(e) @@ -112,7 +111,7 @@ class TestAccuracyMetric(unittest.TestCase): # (9) check map, include unused try: metric = AccuracyMetric(pred='prediction', target='targets') - pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} + pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} target_dict = {'targets': torch.zeros(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict) self.assertDictEqual(metric.get_metric(), {'acc': 1})