Browse Source

FieldArray only checks types when is_input or is_target is set.

tags/v0.2.0^2
FengZiYjun 6 years ago
parent
commit
27833d06ae
3 changed files with 111 additions and 53 deletions
  1. +73
    -37
      fastNLP/core/fieldarray.py
  2. +23
    -0
      test/core/test_fieldarray.py
  3. +15
    -16
      test/core/test_metrics.py

+ 73
- 37
fastNLP/core/fieldarray.py View File

@@ -7,11 +7,11 @@ class FieldArray(object):

"""

def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
def __init__(self, name, content, padding_val=0, is_target=None, is_input=None):
"""

:param str name: the name of the FieldArray
:param list content: a list of int, float, str or np.ndarray, or a list of list of one.
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
:param int padding_val: the integer for padding. Default: 0.
:param bool is_target: If True, this FieldArray is used to compute loss.
:param bool is_input: If True, this FieldArray is used as the model input.
@@ -20,18 +20,44 @@ class FieldArray(object):
if isinstance(content, list):
content = content
elif isinstance(content, np.ndarray):
content = content.tolist()
content = content.tolist() # convert np.ndarray into 2-D list
else:
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content)))
self.content = content
self.padding_val = padding_val
self.is_target = is_target
self.is_input = is_input

self._is_target = None
self._is_input = None

self.BASIC_TYPES = (int, float, str, np.ndarray)
self.is_2d_list = False
self.pytype = self._type_detection(content)
self.pytype = None # int, float, str, or np.ndarray
self.dtype = None # np.int64, np.float64, np.str

if is_input is not None:
self.is_input = is_input
if is_target is not None:
self.is_target = is_target

@property
def is_input(self):
return self._is_input

@is_input.setter
def is_input(self, value):
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)
self._is_input = value

@property
def is_target(self):
return self._is_target

@is_target.setter
def is_target(self, value):
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)
self._is_target = value

def _type_detection(self, content):
"""
@@ -42,9 +68,13 @@ class FieldArray(object):
"""
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list):
# content is a 2-D list
if not all(isinstance(_, list) for _ in content): # strict check 2-D list
raise TypeError("Please provide 2-D list.")
type_set = set([self._type_detection(x) for x in content])
if len(type_set) > 1:
raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set))
if len(type_set) == 2 and int in type_set and float in type_set:
type_set = {float}
elif len(type_set) > 1:
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set))
self.is_2d_list = True
return type_set.pop()

@@ -60,9 +90,9 @@ class FieldArray(object):
# up-cast int to float
return float
else:
raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set))
raise TypeError("Cannot create FieldArray with type {}".format(*type_set))
else:
raise RuntimeError("Cannot create FieldArray with type {}".format(type(content)))
raise TypeError("Cannot create FieldArray with type {}".format(type(content)))

@staticmethod
def _map_to_np_type(basic_type):
@@ -77,33 +107,38 @@ class FieldArray(object):

:param val: int, float, str, or a list of one.
"""
val_type = type(val)
if val_type == list: # shape check
if self.is_2d_list is False:
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
if len(val) == 0:
raise RuntimeError("Cannot append an empty list.")
val_list_type = set([type(_) for _ in val]) # type check
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type:
# up-cast int to float
val_type = float
elif len(val_list_type) == 1:
val_type = val_list_type.pop()
if self.is_target is True or self.is_input is True:
# only check type when used as target or input

val_type = type(val)
if val_type == list: # shape check
if self.is_2d_list is False:
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
if len(val) == 0:
raise RuntimeError("Cannot append an empty list.")
val_list_type = set([type(_) for _ in val]) # type check
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type:
# up-cast int to float
val_type = float
elif len(val_list_type) == 1:
val_type = val_list_type.pop()
else:
raise TypeError("Cannot append a list of {}".format(val_list_type))
else:
raise RuntimeError("Cannot append a list of {}".format(val_list_type))
else:
if self.is_2d_list is True:
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.")
if val_type == float and self.pytype == int:
# up-cast
self.pytype = float
self.dtype = self._map_to_np_type(self.pytype)
elif val_type == int and self.pytype == float:
pass
elif val_type == self.pytype:
pass
else:
raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype))
if self.is_2d_list is True:
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.")
if val_type == float and self.pytype == int:
# up-cast
self.pytype = float
self.dtype = self._map_to_np_type(self.pytype)
elif val_type == int and self.pytype == float:
pass
elif val_type == self.pytype:
pass
else:
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype))
self.content.append(val)

def __getitem__(self, indices):
@@ -121,7 +156,8 @@ class FieldArray(object):
"""
if isinstance(indices, int):
return self.content[indices]
assert self.is_input is True or self.is_target is True
if self.is_input is False and self.is_target is False:
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name))
batch_size = len(indices)
# TODO: when this FieldArray holds scalar content such as seq_length, no padding is needed; this needs further discussion
if not is_iterable(self.content[0]):


+ 23
- 0
test/core/test_fieldarray.py View File

@@ -44,11 +44,34 @@ class TestFieldArray(unittest.TestCase):
def test_support_np_array(self):
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False)
self.assertEqual(fa.dtype, np.ndarray)
self.assertEqual(fa.pytype, np.ndarray)

fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5]))
self.assertEqual(fa.dtype, np.ndarray)
self.assertEqual(fa.pytype, np.ndarray)

fa = FieldArray("my_field", np.random.rand(3, 5), is_input=False)
# in this case, pytype is actually a float. We do not care about it.
self.assertEqual(fa.dtype, np.float64)

def test_nested_list(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)

def test_getitem_v1(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
ans = fa[[0, 1]]
self.assertTrue(isinstance(ans, np.ndarray))
self.assertTrue(isinstance(ans[0], np.ndarray))
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
self.assertEqual(ans.dtype, np.float64)

def test_getitem_v2(self):
x = np.random.rand(10, 5)
fa = FieldArray("my_field", x, is_input=True)
indices = [0, 1, 3, 4, 6]
for a, b in zip(fa[indices], x[indices]):
self.assertListEqual(a.tolist(), b.tolist())

+ 15
- 16
test/core/test_metrics.py View File

@@ -1,9 +1,10 @@

import unittest

from fastNLP.core.metrics import AccuracyMetric
import torch
import numpy as np
import torch

from fastNLP.core.metrics import AccuracyMetric


class TestAccuracyMetric(unittest.TestCase):
def test_AccuracyMetric1(self):
@@ -12,9 +13,9 @@ class TestAccuracyMetric(unittest.TestCase):
target_dict = {'target': torch.zeros(4)}
metric = AccuracyMetric()

metric(pred_dict=pred_dict, target_dict=target_dict, )
metric(pred_dict=pred_dict, target_dict=target_dict, )
print(metric.get_metric())
#
def test_AccuracyMetric2(self):
# (2) with corrupted size
try:
@@ -22,13 +23,13 @@ class TestAccuracyMetric(unittest.TestCase):
target_dict = {'target': torch.zeros(4)}
metric = AccuracyMetric()

metric(pred_dict=pred_dict, target_dict=target_dict, )
metric(pred_dict=pred_dict, target_dict=target_dict, )
print(metric.get_metric())
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."
#
def test_AccuracyMetric3(self):
# (3) the second batch is corrupted size
try:
@@ -47,7 +48,6 @@ class TestAccuracyMetric(unittest.TestCase):
return
self.assertTrue(True, False), "No exception catches."

#
def test_AccuaryMetric4(self):
# (5) check reset
metric = AccuracyMetric()
@@ -57,9 +57,9 @@ class TestAccuracyMetric(unittest.TestCase):
self.assertDictEqual(metric.get_metric(), {'acc': 1})

pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3)+1}
target_dict = {'target': torch.zeros(4, 3) + 1}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc':0})
self.assertDictEqual(metric.get_metric(), {'acc': 0})

def test_AccuaryMetric5(self):
# (5) check reset
@@ -70,11 +70,10 @@ class TestAccuracyMetric(unittest.TestCase):
self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1})

pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3)+1}
target_dict = {'target': torch.zeros(4, 3) + 1}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc':0.5})
self.assertDictEqual(metric.get_metric(), {'acc': 0.5})

#
def test_AccuaryMetric6(self):
# (6) check numpy array is not acceptable
try:
@@ -99,9 +98,9 @@ class TestAccuracyMetric(unittest.TestCase):
# (8) check map, does not match. use stop_fast_param to stop fast param map
try:
metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param":1}
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict, )
metric(pred_dict=pred_dict, target_dict=target_dict, )
self.assertDictEqual(metric.get_metric(), {'acc': 1})
except Exception as e:
print(e)
@@ -112,7 +111,7 @@ class TestAccuracyMetric(unittest.TestCase):
# (9) check map, include unused
try:
metric = AccuracyMetric(pred='prediction', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1}
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 1})


Loading…
Cancel
Save