Browse Source

Add auto type detection/conversion in FieldArray

* In init, detect content type to be Python int, float, or str.
* In append(), check type consistency.
* In init & append(), int will be cast into float if they occur together.
* Map Python type into numpy dtype
* Raise error if type detection fails.
tags/v0.2.0^2
FengZiYjun 6 years ago
parent
commit
6839bb91cc
2 changed files with 72 additions and 15 deletions
  1. +52
    -15
      fastNLP/core/fieldarray.py
  2. +20
    -0
      test/core/test_fieldarray.py

+ 52
- 15
fastNLP/core/fieldarray.py View File

@@ -6,6 +6,7 @@ class FieldArray(object):
It is the basic element of DataSet class.

"""

def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
"""

@@ -20,21 +21,56 @@ class FieldArray(object):
self.padding_val = padding_val
self.is_target = is_target
self.is_input = is_input
# TODO: auto detect dtype
self.dtype = None
self.pytype = self._type_detection(content)
self.dtype = self._map_to_np_type(self.pytype)

@staticmethod
def _type_detection(content):
    """Detect the common Python type of the items in ``content``.

    A mix of ``int`` and ``float`` items is resolved by up-casting every item
    to ``float``; this rewrites ``content`` in place.

    :param content: a mutable sequence (e.g. a list) of int, float, or str items.
    :return: the detected type: ``int``, ``float``, or ``str``.
    :raises ValueError: if ``content`` is empty, holds an unsupported type, or
        mixes types other than int and float.
    """
    type_set = {type(item) for item in content}
    if len(type_set) == 1 and type_set <= {str, int, float}:
        return type_set.pop()
    if type_set == {int, float}:
        # up-cast int to float, in place so the caller's sequence is updated
        for idx, item in enumerate(content):
            content[idx] = float(item)
        return float
    # Name every offending type. Unpacking the set into a single "{}" showed
    # only one of them and raised IndexError (not ValueError) when it was empty.
    type_names = ", ".join(sorted(t.__name__ for t in type_set)) or "<empty content>"
    raise ValueError("Unsupported type conversion detected in FieldArray: {}".format(type_names))

@staticmethod
def _map_to_np_type(basic_type):
    """Map a Python built-in type to the dtype used when building numpy arrays.

    :param basic_type: one of ``int``, ``float``, or ``str``.
    :return: ``np.int64`` for int, ``np.double`` for float, ``str`` for str.
    :raises KeyError: if ``basic_type`` is not one of the supported types.
    """
    # ``np.str`` was merely an alias of the builtin ``str``; the alias was
    # deprecated in NumPy 1.20 and removed in 1.24, so use the builtin directly.
    type_mapping = {int: np.int64, float: np.double, str: str}
    return type_mapping[basic_type]

def __repr__(self):
    """Return a string with the field name followed by the repr of its content."""
    content_repr = repr(self.content)
    return "FieldArray {}: {}".format(self.name, content_repr)

def append(self, val):
    """Add a new item to the tail of FieldArray.

    An int appended to a float field is up-cast to float. A float appended to
    an int field up-casts every stored item to float and switches the field's
    ``pytype`` and ``dtype`` to their float counterparts.

    :param val: int, float, or str.
    :raises ValueError: if the type of ``val`` does not match the field's type
        and is not covered by the int/float up-cast rules.
    """
    val_type = type(val)
    if val_type is int and self.pytype is float:
        # up-cast the appended value
        val = float(val)
    elif val_type is float and self.pytype is int:
        # up-cast all existing values; slice assignment keeps the same list object
        self.content[:] = [float(item) for item in self.content]
        self.pytype = float
        self.dtype = self._map_to_np_type(self.pytype)
    elif val_type != self.pytype:
        raise ValueError(
            "Cannot append a {}-type value into a {}-type FieldArray.".format(
                val_type.__name__, self.pytype.__name__
            )
        )
    self.content.append(val)

def __getitem__(self, indices):
    """Fetch items via subscript syntax by delegating to ``get``.

    :param indices: forwarded unchanged to ``get`` (an int, or a list of int).
    :return: whatever ``get`` returns for ``indices``.
    """
    return self.get(indices)

def __setitem__(self, idx, val):
    """Overwrite the item stored at position ``idx``.

    :param idx: int position within the content.
    :param val: the new value; it is stored as-is, without the type checks
        that ``append`` performs.
    """
    assert isinstance(idx, int)
    self.content[idx] = val

def get(self, indices):
"""Fetch instances based on indices.
@@ -42,31 +78,32 @@ class FieldArray(object):
:param indices: an int, or a list of int.
:return:
"""
# TODO: the return behavior is inconsistent (a single item for an int, an array for a list), which is a latent hazard
if isinstance(indices, int):
return self.content[indices]
assert self.is_input is True or self.is_target is True
batch_size = len(indices)
# TODO: when this FieldArray holds scalar content such as seq_length, no padding is needed; this needs further discussion
if not isiterable(self.content[0]):
if self.dtype is None:
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double
if not is_iterable(self.content[0]):
array = np.array([self.content[i] for i in indices], dtype=self.dtype)
else:
if self.dtype is None:
self.dtype = np.int64
max_len = max([len(self.content[i]) for i in indices])
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)

for i, idx in enumerate(indices):
array[i][:len(self.content[idx])] = self.content[idx]
return array

def __len__(self):
    """Return the number of items held in ``content``.

    :return int length:
    """
    size = len(self.content)
    return size

def is_iterable(content):
    """Tell whether ``content`` can be iterated over.

    Strings count as iterable.

    :param content: any object.
    :return: True if ``iter(content)`` succeeds, False if it raises TypeError.
    """
    try:
        iter(content)
    except TypeError:
        return False
    return True


# Backward-compatible alias for the previous spelling of this helper.
isiterable = is_iterable

+ 20
- 0
test/core/test_fieldarray.py View File

@@ -20,3 +20,23 @@ class TestFieldArray(unittest.TestCase):
self.assertEqual(fa.get(0), 1)
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])

def test_type_conversion(self):
    """Check type detection at construction and up-casting on append."""
    # ints mixed with floats at construction are up-cast to float
    fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True)
    self.assertEqual(fa.pytype, float)
    self.assertEqual(fa.dtype, np.double)

    # appending a float to an int field converts the whole field to float
    fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
    fa.append(1.3333)
    self.assertEqual(fa.pytype, float)
    self.assertEqual(fa.dtype, np.double)

    # appending an int to a float field leaves the field as float
    fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False)
    fa.append(10)
    self.assertEqual(fa.pytype, float)
    self.assertEqual(fa.dtype, np.double)

    # ``np.str`` was an alias of the builtin ``str`` and was removed in
    # NumPy 1.24, so compare against the builtin instead.
    fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False)
    fa.append("e")
    self.assertEqual(fa.dtype, str)
    self.assertEqual(fa.pytype, str)

Loading…
Cancel
Save