diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 1b1a89c1..a1ece0aa 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -11,7 +11,7 @@ class FieldArray(object): """ :param str name: the name of the FieldArray - :param list content: a list of int, float, or a list of list. + :param list content: a list of int, float, str or np.ndarray, or a list of list of one. :param int padding_val: the integer for padding. Default: 0. :param bool is_target: If True, this FieldArray is used to compute loss. :param bool is_input: If True, this FieldArray is used to the model input. @@ -27,35 +27,46 @@ class FieldArray(object): self.padding_val = padding_val self.is_target = is_target self.is_input = is_input + + self.BASIC_TYPES = (int, float, str, np.ndarray) + self.is_2d_list = False self.pytype = self._type_detection(content) self.dtype = self._map_to_np_type(self.pytype) - @staticmethod - def _type_detection(content): + def _type_detection(self, content): + """ + :param content: a list of int, float, str or np.ndarray, or a list of list of one. + :return type: one of int, float, str, np.ndarray + + """ if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): - # 2-D list - # TODO: refactor - type_set = set([type(item) for item in content[0]]) - else: - # 1-D list + # content is a 2-D list + type_set = set([self._type_detection(x) for x in content]) + if len(type_set) > 1: + raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) + self.is_2d_list = True + return type_set.pop() + + elif isinstance(content, list): + # content is a 1-D list if len(content) == 0: raise RuntimeError("Cannot create FieldArray with an empty list.") type_set = set([type(item) for item in content]) - if len(type_set) == 1 and any(basic_type in type_set for basic_type in (str, int, float)): - return type_set.pop() - elif len(type_set) == 2 and float in type_set and int in type_set: - # up-cast int to float - for idx, _ in enumerate(content): - content[idx] = float(content[idx]) - return float + if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: + return type_set.pop() + elif len(type_set) == 2 and float in type_set and int in type_set: + # up-cast int to float + return float + else: + raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set)) else: - raise ValueError("Unsupported type conversion detected in FieldArray: {}".format(*type_set)) + raise RuntimeError("Cannot create FieldArray with type {}".format(type(content))) @staticmethod def _map_to_np_type(basic_type): - type_mapping = {int: np.int64, float: np.float64, str: np.str} + type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} return type_mapping[basic_type] def __repr__(self): @@ -64,29 +75,35 @@ class FieldArray(object): def append(self, val): """Add a new item to the tail of FieldArray. - :param val: int, float, str, or a list of them. + :param val: int, float, str, or a list of one. """ val_type = type(val) - if val_type is int and self.pytype is float: - # up-cast the appended value - val = float(val) - elif val_type is float and self.pytype is int: - # up-cast all other values in the content - for idx, _ in enumerate(self.content): - self.content[idx] = float(self.content[idx]) - self.pytype = float - self.dtype = self._map_to_np_type(self.pytype) - elif val_type is list: + if val_type == list: # shape check + if self.is_2d_list is False: + raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") if len(val) == 0: - raise ValueError("Cannot append an empty list.") + raise RuntimeError("Cannot append an empty list.") + val_list_type = [type(_) for _ in val] # type check + if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: + # up-cast int to float + val_type = float + elif len(val_list_type) == 1: + val_type = val_list_type[0] else: - if type(val[0]) != self.pytype: - raise ValueError( - "Cannot append a list of {}-type value into a {}-tpye FieldArray.". - format(type(val[0]), self.pytype)) - elif val_type != self.pytype: - raise ValueError("Cannot append a {}-type value into a {}-tpye FieldArray.".format(val_type, self.pytype)) - + raise RuntimeError("Cannot append a list of {}".format(val_list_type)) + else: + if self.is_2d_list is True: + raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") + if val_type == float and self.pytype == int: + # up-cast + self.pytype = float + self.dtype = self._map_to_np_type(self.pytype) + elif val_type == int and self.pytype == float: + pass + elif val_type == self.pytype: + pass + else: + raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype)) self.content.append(val) def __getitem__(self, indices): @@ -102,7 +119,6 @@ class FieldArray(object): :param indices: an int, or a list of int. :return: """ - # TODO: 返回行为不一致,有隐患 if isinstance(indices, int): return self.content[indices] assert self.is_input is True or self.is_target is True diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index f2fb16d0..af3d2ef0 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -126,6 +126,7 @@ class LossBase(object): for keys, val in target_dict.items(): param_val_dict.update({keys: val}) + # TODO: use the origin key to raise error if not self._checked: for keys in args: if param_map[keys] not in param_val_dict.keys(): diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index 883e1136..0264c2ff 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -24,19 +24,31 @@ class TestFieldArray(unittest.TestCase): def test_type_conversion(self): fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) self.assertEqual(fa.pytype, float) - self.assertEqual(fa.dtype, np.double) + self.assertEqual(fa.dtype, np.float64) fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) fa.append(1.3333) self.assertEqual(fa.pytype, float) - self.assertEqual(fa.dtype, np.double) + self.assertEqual(fa.dtype, np.float64) fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False) fa.append(10) self.assertEqual(fa.pytype, float) - self.assertEqual(fa.dtype, np.double) + self.assertEqual(fa.dtype, np.float64) fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False) fa.append("e") self.assertEqual(fa.dtype, np.str) self.assertEqual(fa.pytype, str) + + def test_support_np_array(self): + fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False) + self.assertEqual(fa.dtype, np.ndarray) + + fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) + self.assertEqual(fa.pytype, np.ndarray) + + def test_nested_list(self): + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False) + self.assertEqual(fa.pytype, float) + self.assertEqual(fa.dtype, np.float64)