diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 3a63f788..f93fbf2e 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -6,6 +6,7 @@ class FieldArray(object): It is the basic element of DataSet class. """ + def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): """ @@ -20,21 +21,56 @@ class FieldArray(object): self.padding_val = padding_val self.is_target = is_target self.is_input = is_input - # TODO: auto detect dtype - self.dtype = None + self.pytype = self._type_detection(content) + self.dtype = self._map_to_np_type(self.pytype) + + @staticmethod + def _type_detection(content): + type_set = set([type(item) for item in content]) + if len(type_set) == 1 and any(basic_type in type_set for basic_type in (str, int, float)): + return type_set.pop() + elif len(type_set) == 2 and float in type_set and int in type_set: + # up-cast int to float + for idx, _ in enumerate(content): + content[idx] = float(content[idx]) + return float + else: + raise ValueError("Unsupported type conversion detected in FieldArray: {}".format(*type_set)) + + @staticmethod + def _map_to_np_type(basic_type): + type_mapping = {int: np.int64, float: np.double, str: np.str} + return type_mapping[basic_type] def __repr__(self): return "FieldArray {}: {}".format(self.name, self.content.__repr__()) def append(self, val): + """Add a new item to the tail of FieldArray. + + :param val: int, float, or str. + """ + val_type = type(val) + if val_type is int and self.pytype is float: + # up-cast the appended value + val = float(val) + elif val_type is float and self.pytype is int: + # up-cast all other values in the content + for idx, _ in enumerate(self.content): + self.content[idx] = float(self.content[idx]) + self.pytype = float + self.dtype = self._map_to_np_type(self.pytype) + + elif val_type != self.pytype: + raise ValueError("Cannot append a {}-type value into a {}-tpye FieldArray.".format(val_type, self.pytype)) self.content.append(val) - def __getitem__(self, name): - return self.get(name) + def __getitem__(self, indices): + return self.get(indices) - def __setitem__(self, name, val): - assert isinstance(name, int) - self.content[name] = val + def __setitem__(self, idx, val): + assert isinstance(idx, int) + self.content[idx] = val def get(self, indices): """Fetch instances based on indices. @@ -42,31 +78,32 @@ class FieldArray(object): :param indices: an int, or a list of int. :return: """ + # TODO: 返回行为不一致,有隐患 if isinstance(indices, int): return self.content[indices] assert self.is_input is True or self.is_target is True batch_size = len(indices) # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 - if not isiterable(self.content[0]): - if self.dtype is None: - self.dtype = np.int64 if isinstance(self.content[0], int) else np.double + if not is_iterable(self.content[0]): array = np.array([self.content[i] for i in indices], dtype=self.dtype) else: - if self.dtype is None: - self.dtype = np.int64 max_len = max([len(self.content[i]) for i in indices]) array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) - for i, idx in enumerate(indices): array[i][:len(self.content[idx])] = self.content[idx] return array def __len__(self): + """Returns the size of FieldArray. + + :return int length: + """ return len(self.content) -def isiterable(content): + +def is_iterable(content): try: _ = (e for e in content) except TypeError: return False - return True \ No newline at end of file + return True diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index 07f02c54..883e1136 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -20,3 +20,23 @@ class TestFieldArray(unittest.TestCase): self.assertEqual(fa.get(0), 1) self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) + + def test_type_conversion(self): + fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) + self.assertEqual(fa.pytype, float) + self.assertEqual(fa.dtype, np.double) + + fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) + fa.append(1.3333) + self.assertEqual(fa.pytype, float) + self.assertEqual(fa.dtype, np.double) + + fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False) + fa.append(10) + self.assertEqual(fa.pytype, float) + self.assertEqual(fa.dtype, np.double) + + fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False) + fa.append("e") + self.assertEqual(fa.dtype, np.str) + self.assertEqual(fa.pytype, str)