Browse Source

Improve FieldArray. Support nested list and a list of np.array

tags/v0.2.0^2
FengZiYjun 5 years ago
parent
commit
661780b975
3 changed files with 69 additions and 40 deletions
  1. +53
    -37
      fastNLP/core/fieldarray.py
  2. +1
    -0
      fastNLP/core/losses.py
  3. +15
    -3
      test/core/test_fieldarray.py

+ 53
- 37
fastNLP/core/fieldarray.py View File

@@ -11,7 +11,7 @@ class FieldArray(object):
""" """


:param str name: the name of the FieldArray :param str name: the name of the FieldArray
:param list content: a list of int, float, or a list of list.
:param list content: a list of int, float, str or np.ndarray, or a list of list of one.
:param int padding_val: the integer for padding. Default: 0. :param int padding_val: the integer for padding. Default: 0.
:param bool is_target: If True, this FieldArray is used to compute loss. :param bool is_target: If True, this FieldArray is used to compute loss.
:param bool is_input: If True, this FieldArray is used to the model input. :param bool is_input: If True, this FieldArray is used to the model input.
@@ -27,35 +27,46 @@ class FieldArray(object):
self.padding_val = padding_val self.padding_val = padding_val
self.is_target = is_target self.is_target = is_target
self.is_input = is_input self.is_input = is_input

self.BASIC_TYPES = (int, float, str, np.ndarray)
self.is_2d_list = False
self.pytype = self._type_detection(content) self.pytype = self._type_detection(content)
self.dtype = self._map_to_np_type(self.pytype) self.dtype = self._map_to_np_type(self.pytype)


def _type_detection(self, content):
    """Infer the single Python type shared by the elements of ``content``.

    As a side effect, ``self.is_2d_list`` is set to True when ``content`` is
    a non-empty list whose first element is itself a list.

    :param content: a list of int, float, str or np.ndarray, or a list of such lists.
    :return type: one of int, float, str, np.ndarray. A mixture of int and
        float is reported as float.
    :raises RuntimeError: if ``content`` is not a list, is an empty list, or
        holds unsupported or incompatible element types.
    """
    if not isinstance(content, list):
        raise RuntimeError("Cannot create FieldArray with type {}".format(type(content)))
    if len(content) == 0:
        raise RuntimeError("Cannot create FieldArray with an empty list.")

    if isinstance(content[0], list):
        # content is a 2-D list: resolve every row, then reconcile the results.
        type_set = {self._type_detection(row) for row in content}
        if type_set == {int, float}:
            # up-cast int rows to float, matching the 1-D rule below
            type_set = {float}
        if len(type_set) > 1:
            raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set))
        self.is_2d_list = True
        return type_set.pop()

    # content is a 1-D list
    type_set = {type(item) for item in content}
    if len(type_set) == 1 and next(iter(type_set)) in self.BASIC_TYPES:
        return type_set.pop()
    if type_set == {int, float}:
        # up-cast int to float
        return float
    # Report every offending type (sorted so the message is deterministic).
    found = ", ".join(sorted(str(item_type) for item_type in type_set))
    raise RuntimeError("Cannot create FieldArray with type {}".format(found))


@staticmethod
def _map_to_np_type(basic_type):
    """Map a detected Python type to the dtype recorded for the field.

    :param basic_type: one of int, float, str, np.ndarray.
    :return: np.int64, np.float64, str or np.ndarray respectively.
    :raises KeyError: if ``basic_type`` is not one of the supported types.
    """
    # ``np.str`` was only an alias of the builtin ``str`` and has been removed
    # from NumPy; the builtin is the same object and works on every version.
    type_mapping = {int: np.int64, float: np.float64, str: str, np.ndarray: np.ndarray}
    return type_mapping[basic_type]


def __repr__(self): def __repr__(self):
@@ -64,29 +75,35 @@ class FieldArray(object):
def append(self, val):
    """Add a new item to the tail of FieldArray.

    A list may only be appended to a 2-D FieldArray and a non-list only to a
    1-D one. Appending float data to an int FieldArray switches ``pytype`` and
    ``dtype`` to float; appending int data to a float FieldArray is accepted.

    :param val: int, float, str, or a list of one.
    :raises RuntimeError: on a shape mismatch, an empty list, a list mixing
        types other than int and float, or a type incompatible with ``pytype``.
    """
    val_type = type(val)
    if val_type is list:  # shape check
        if self.is_2d_list is False:
            raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
        if len(val) == 0:
            raise RuntimeError("Cannot append an empty list.")
        # type check: collect the DISTINCT element types. Counting the raw
        # list of types would measure the length of ``val`` instead.
        val_list_type = {type(item) for item in val}
        if val_list_type == {int, float}:
            # up-cast int to float
            val_type = float
        elif len(val_list_type) == 1:
            val_type = next(iter(val_list_type))
        else:
            raise RuntimeError("Cannot append a list of {}".format(val_list_type))
    elif self.is_2d_list is True:
        raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.")

    if val_type == float and self.pytype == int:
        # up-cast
        self.pytype = float
        self.dtype = self._map_to_np_type(self.pytype)
    elif not (val_type == self.pytype or (val_type == int and self.pytype == float)):
        raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype))
    self.content.append(val)


def __getitem__(self, indices): def __getitem__(self, indices):
@@ -102,7 +119,6 @@ class FieldArray(object):
:param indices: an int, or a list of int. :param indices: an int, or a list of int.
:return: :return:
""" """
# TODO: 返回行为不一致,有隐患
if isinstance(indices, int): if isinstance(indices, int):
return self.content[indices] return self.content[indices]
assert self.is_input is True or self.is_target is True assert self.is_input is True or self.is_target is True


+ 1
- 0
fastNLP/core/losses.py View File

@@ -126,6 +126,7 @@ class LossBase(object):
for keys, val in target_dict.items(): for keys, val in target_dict.items():
param_val_dict.update({keys: val}) param_val_dict.update({keys: val})


# TODO: use the origin key to raise error
if not self._checked: if not self._checked:
for keys in args: for keys in args:
if param_map[keys] not in param_val_dict.keys(): if param_map[keys] not in param_val_dict.keys():


+ 15
- 3
test/core/test_fieldarray.py View File

@@ -24,19 +24,31 @@ class TestFieldArray(unittest.TestCase):
def test_type_conversion(self):
    """Check pytype/dtype detection at construction time and after append."""
    # Mixed int/float content is detected as float.
    fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True)
    self.assertEqual(fa.pytype, float)
    self.assertEqual(fa.dtype, np.float64)

    # Appending a float to int content switches the field to float.
    fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
    fa.append(1.3333)
    self.assertEqual(fa.pytype, float)
    self.assertEqual(fa.dtype, np.float64)

    # Appending an int to float content leaves the field as float.
    fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False)
    fa.append(10)
    self.assertEqual(fa.pytype, float)
    self.assertEqual(fa.dtype, np.float64)

    # ``np.str`` was an alias of the builtin ``str`` and no longer exists in
    # NumPy, so compare against the builtin directly.
    fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False)
    fa.append("e")
    self.assertEqual(fa.dtype, str)
    self.assertEqual(fa.pytype, str)

def test_support_np_array(self):
    """A FieldArray holding np.ndarray items reports np.ndarray as its type."""
    values = [1.1, 2.2, 3.3, 4.4, 5.5]
    field = FieldArray("y", [np.array(values)], is_input=False)
    self.assertEqual(field.dtype, np.ndarray)

    # Appending a second array must keep the detected Python type.
    field.append(np.array(values))
    self.assertEqual(field.pytype, np.ndarray)

def test_nested_list(self):
    """A 2-D list of floats is detected as float with dtype np.float64."""
    row = [1.1, 2.2, 3.3, 4.4, 5.5]
    # Two separate list objects, as in a literal nested list.
    nested = FieldArray("y", [list(row), list(row)], is_input=False)
    self.assertEqual(nested.pytype, float)
    self.assertEqual(nested.dtype, np.float64)

Loading…
Cancel
Save