@@ -67,8 +67,8 @@ class DataSet(object): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.idx = idx | self.idx = idx | ||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx) | |||||
assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx]) | assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx]) | ||||
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | |||||
return self.dataset.field_arrays[item][self.idx] | return self.dataset.field_arrays[item][self.idx] | ||||
def __repr__(self): | def __repr__(self): | ||||
return self.dataset[self.idx].__repr__() | return self.dataset[self.idx].__repr__() | ||||
@@ -7,11 +7,11 @@ class FieldArray(object): | |||||
""" | """ | ||||
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): | |||||
def __init__(self, name, content, padding_val=0, is_target=None, is_input=None): | |||||
""" | """ | ||||
:param str name: the name of the FieldArray | :param str name: the name of the FieldArray | ||||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one. | |||||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | |||||
:param int padding_val: the integer for padding. Default: 0. | :param int padding_val: the integer for padding. Default: 0. | ||||
:param bool is_target: If True, this FieldArray is used to compute loss. | :param bool is_target: If True, this FieldArray is used to compute loss. | ||||
:param bool is_input: If True, this FieldArray is used to the model input. | :param bool is_input: If True, this FieldArray is used to the model input. | ||||
@@ -20,18 +20,44 @@ class FieldArray(object): | |||||
if isinstance(content, list): | if isinstance(content, list): | ||||
content = content | content = content | ||||
elif isinstance(content, np.ndarray): | elif isinstance(content, np.ndarray): | ||||
content = content.tolist() | |||||
content = content.tolist() # convert np.ndarray into 2-D list | |||||
else: | else: | ||||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | ||||
self.content = content | self.content = content | ||||
self.padding_val = padding_val | self.padding_val = padding_val | ||||
self.is_target = is_target | |||||
self.is_input = is_input | |||||
self._is_target = None | |||||
self._is_input = None | |||||
self.BASIC_TYPES = (int, float, str, np.ndarray) | self.BASIC_TYPES = (int, float, str, np.ndarray) | ||||
self.is_2d_list = False | self.is_2d_list = False | ||||
self.pytype = self._type_detection(content) | |||||
self.pytype = None # int, float, str, or np.ndarray | |||||
self.dtype = None # np.int64, np.float64, np.str | |||||
if is_input is not None: | |||||
self.is_input = is_input | |||||
if is_target is not None: | |||||
self.is_target = is_target | |||||
@property | |||||
def is_input(self): | |||||
return self._is_input | |||||
@is_input.setter | |||||
def is_input(self, value): | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
self._is_input = value | |||||
@property | |||||
def is_target(self): | |||||
return self._is_target | |||||
@is_target.setter | |||||
def is_target(self, value): | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | self.dtype = self._map_to_np_type(self.pytype) | ||||
self._is_target = value | |||||
def _type_detection(self, content): | def _type_detection(self, content): | ||||
""" | """ | ||||
@@ -42,9 +68,13 @@ class FieldArray(object): | |||||
""" | """ | ||||
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | ||||
# content is a 2-D list | # content is a 2-D list | ||||
if not all(isinstance(_, list) for _ in content): # strict check 2-D list | |||||
raise TypeError("Please provide 2-D list.") | |||||
type_set = set([self._type_detection(x) for x in content]) | type_set = set([self._type_detection(x) for x in content]) | ||||
if len(type_set) > 1: | |||||
raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||||
if len(type_set) == 2 and int in type_set and float in type_set: | |||||
type_set = {float} | |||||
elif len(type_set) > 1: | |||||
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||||
self.is_2d_list = True | self.is_2d_list = True | ||||
return type_set.pop() | return type_set.pop() | ||||
@@ -60,9 +90,9 @@ class FieldArray(object): | |||||
# up-cast int to float | # up-cast int to float | ||||
return float | return float | ||||
else: | else: | ||||
raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set)) | |||||
raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) | |||||
else: | else: | ||||
raise RuntimeError("Cannot create FieldArray with type {}".format(type(content))) | |||||
raise TypeError("Cannot create FieldArray with type {}".format(type(content))) | |||||
@staticmethod | @staticmethod | ||||
def _map_to_np_type(basic_type): | def _map_to_np_type(basic_type): | ||||
@@ -77,33 +107,38 @@ class FieldArray(object): | |||||
:param val: int, float, str, or a list of one. | :param val: int, float, str, or a list of one. | ||||
""" | """ | ||||
val_type = type(val) | |||||
if val_type == list: # shape check | |||||
if self.is_2d_list is False: | |||||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||||
if len(val) == 0: | |||||
raise RuntimeError("Cannot append an empty list.") | |||||
val_list_type = set([type(_) for _ in val]) # type check | |||||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||||
# up-cast int to float | |||||
val_type = float | |||||
elif len(val_list_type) == 1: | |||||
val_type = val_list_type.pop() | |||||
if self.is_target is True or self.is_input is True: | |||||
# only check type when used as target or input | |||||
val_type = type(val) | |||||
if val_type == list: # shape check | |||||
if self.is_2d_list is False: | |||||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||||
if len(val) == 0: | |||||
raise RuntimeError("Cannot append an empty list.") | |||||
val_list_type = set([type(_) for _ in val]) # type check | |||||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||||
# up-cast int to float | |||||
val_type = float | |||||
elif len(val_list_type) == 1: | |||||
val_type = val_list_type.pop() | |||||
else: | |||||
raise TypeError("Cannot append a list of {}".format(val_list_type)) | |||||
else: | else: | ||||
raise RuntimeError("Cannot append a list of {}".format(val_list_type)) | |||||
else: | |||||
if self.is_2d_list is True: | |||||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||||
if val_type == float and self.pytype == int: | |||||
# up-cast | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
elif val_type == int and self.pytype == float: | |||||
pass | |||||
elif val_type == self.pytype: | |||||
pass | |||||
else: | |||||
raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||||
if self.is_2d_list is True: | |||||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||||
if val_type == float and self.pytype == int: | |||||
# up-cast | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
elif val_type == int and self.pytype == float: | |||||
pass | |||||
elif val_type == self.pytype: | |||||
pass | |||||
else: | |||||
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||||
self.content.append(val) | self.content.append(val) | ||||
def __getitem__(self, indices): | def __getitem__(self, indices): | ||||
@@ -121,7 +156,8 @@ class FieldArray(object): | |||||
""" | """ | ||||
if isinstance(indices, int): | if isinstance(indices, int): | ||||
return self.content[indices] | return self.content[indices] | ||||
assert self.is_input is True or self.is_target is True | |||||
if self.is_input is False and self.is_target is False: | |||||
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | |||||
batch_size = len(indices) | batch_size = len(indices) | ||||
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | ||||
if not is_iterable(self.content[0]): | if not is_iterable(self.content[0]): | ||||
@@ -118,7 +118,7 @@ class LossBase(object): | |||||
if not self._checked: | if not self._checked: | ||||
for keys, val in pred_dict.items(): | for keys, val in pred_dict.items(): | ||||
if keys in target_dict.keys(): | if keys in target_dict.keys(): | ||||
duplicated.append(keys) | |||||
duplicated.append(param_map[keys]) | |||||
param_val_dict = {} | param_val_dict = {} | ||||
for keys, val in pred_dict.items(): | for keys, val in pred_dict.items(): | ||||
@@ -126,11 +126,10 @@ class LossBase(object): | |||||
for keys, val in target_dict.items(): | for keys, val in target_dict.items(): | ||||
param_val_dict.update({keys: val}) | param_val_dict.update({keys: val}) | ||||
# TODO: use the origin key to raise error | |||||
if not self._checked: | if not self._checked: | ||||
for keys in args: | for keys in args: | ||||
if param_map[keys] not in param_val_dict.keys(): | if param_map[keys] not in param_val_dict.keys(): | ||||
missing.append(keys) | |||||
missing.append(param_map[keys]) | |||||
if len(duplicated) > 0 or len(missing) > 0: | if len(duplicated) > 0 or len(missing) > 0: | ||||
raise CheckError( | raise CheckError( | ||||
@@ -33,7 +33,6 @@ class CNNText(torch.nn.Module): | |||||
padding=padding) | padding=padding) | ||||
self.dropout = nn.Dropout(dropout) | self.dropout = nn.Dropout(dropout) | ||||
self.fc = encoder.Linear(sum(kernel_nums), num_classes) | self.fc = encoder.Linear(sum(kernel_nums), num_classes) | ||||
self._loss = nn.CrossEntropyLoss() | |||||
def forward(self, word_seq): | def forward(self, word_seq): | ||||
""" | """ | ||||
@@ -56,25 +55,3 @@ class CNNText(torch.nn.Module): | |||||
output = self(word_seq) | output = self(word_seq) | ||||
_, predict = output['output'].max(dim=1) | _, predict = output['output'].max(dim=1) | ||||
return {'predict': predict} | return {'predict': predict} | ||||
def get_loss(self, output, label_seq): | |||||
""" | |||||
:param output: output of forward(), [batch_size, seq_len] | |||||
:param label_seq: true label in DataSet, [batch_size, seq_len] | |||||
:return loss: torch.Tensor | |||||
""" | |||||
return self._loss(output, label_seq) | |||||
def evaluate(self, predict, label_seq): | |||||
""" | |||||
:param predict: iterable predict tensors | |||||
:param label_seq: iterable true label tensors | |||||
:return accuracy: dict of float | |||||
""" | |||||
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0) | |||||
predict, label_seq = predict.squeeze(), label_seq.squeeze() | |||||
correct = (predict == label_seq).long().sum().item() | |||||
total = label_seq.size(0) | |||||
return {'acc': 1.0 * correct / total} |
@@ -44,11 +44,34 @@ class TestFieldArray(unittest.TestCase): | |||||
def test_support_np_array(self): | def test_support_np_array(self): | ||||
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False) | fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False) | ||||
self.assertEqual(fa.dtype, np.ndarray) | self.assertEqual(fa.dtype, np.ndarray) | ||||
self.assertEqual(fa.pytype, np.ndarray) | |||||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | ||||
self.assertEqual(fa.dtype, np.ndarray) | |||||
self.assertEqual(fa.pytype, np.ndarray) | self.assertEqual(fa.pytype, np.ndarray) | ||||
fa = FieldArray("my_field", np.random.rand(3, 5), is_input=False) | |||||
# in this case, pytype is actually a float. We do not care about it. | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
def test_nested_list(self): | def test_nested_list(self): | ||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False) | fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False) | ||||
self.assertEqual(fa.pytype, float) | self.assertEqual(fa.pytype, float) | ||||
self.assertEqual(fa.dtype, np.float64) | self.assertEqual(fa.dtype, np.float64) | ||||
def test_getitem_v1(self): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||||
ans = fa[[0, 1]] | |||||
self.assertTrue(isinstance(ans, np.ndarray)) | |||||
self.assertTrue(isinstance(ans[0], np.ndarray)) | |||||
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) | |||||
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) | |||||
self.assertEqual(ans.dtype, np.float64) | |||||
def test_getitem_v2(self): | |||||
x = np.random.rand(10, 5) | |||||
fa = FieldArray("my_field", x, is_input=True) | |||||
indices = [0, 1, 3, 4, 6] | |||||
for a, b in zip(fa[indices], x[indices]): | |||||
self.assertListEqual(a.tolist(), b.tolist()) |
@@ -319,3 +319,12 @@ class TestLosserError(unittest.TestCase): | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | print(los(pred_dict=pred_dict, target_dict=target_dict)) | ||||
def test_check_error(self): | |||||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||||
b = torch.tensor([1, 0, 4]) | |||||
with self.assertRaises(Exception): | |||||
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | |||||
with self.assertRaises(Exception): | |||||
ans = l1({"my_predict": a}, {"truth": b, "my": a}) |
@@ -1,9 +1,10 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.core.metrics import AccuracyMetric | |||||
import torch | |||||
import numpy as np | import numpy as np | ||||
import torch | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
class TestAccuracyMetric(unittest.TestCase): | class TestAccuracyMetric(unittest.TestCase): | ||||
def test_AccuracyMetric1(self): | def test_AccuracyMetric1(self): | ||||
@@ -12,9 +13,9 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
# | |||||
def test_AccuracyMetric2(self): | def test_AccuracyMetric2(self): | ||||
# (2) with corrupted size | # (2) with corrupted size | ||||
try: | try: | ||||
@@ -22,13 +23,13 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
except Exception as e: | except Exception as e: | ||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
# | |||||
def test_AccuracyMetric3(self): | def test_AccuracyMetric3(self): | ||||
# (3) the second batch is corrupted size | # (3) the second batch is corrupted size | ||||
try: | try: | ||||
@@ -47,7 +48,6 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
# | |||||
def test_AccuaryMetric4(self): | def test_AccuaryMetric4(self): | ||||
# (5) check reset | # (5) check reset | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
@@ -57,9 +57,9 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | self.assertDictEqual(metric.get_metric(), {'acc': 1}) | ||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | pred_dict = {"pred": torch.zeros(4, 3, 2)} | ||||
target_dict = {'target': torch.zeros(4, 3)+1} | |||||
target_dict = {'target': torch.zeros(4, 3) + 1} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
self.assertDictEqual(metric.get_metric(), {'acc':0}) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 0}) | |||||
def test_AccuaryMetric5(self): | def test_AccuaryMetric5(self): | ||||
# (5) check reset | # (5) check reset | ||||
@@ -70,11 +70,10 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) | self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) | ||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | pred_dict = {"pred": torch.zeros(4, 3, 2)} | ||||
target_dict = {'target': torch.zeros(4, 3)+1} | |||||
target_dict = {'target': torch.zeros(4, 3) + 1} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
self.assertDictEqual(metric.get_metric(), {'acc':0.5}) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 0.5}) | |||||
# | |||||
def test_AccuaryMetric6(self): | def test_AccuaryMetric6(self): | ||||
# (6) check numpy array is not acceptable | # (6) check numpy array is not acceptable | ||||
try: | try: | ||||
@@ -99,9 +98,9 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
# (8) check map, does not match. use stop_fast_param to stop fast param map | # (8) check map, does not match. use stop_fast_param to stop fast param map | ||||
try: | try: | ||||
metric = AccuracyMetric(pred='predictions', target='targets') | metric = AccuracyMetric(pred='predictions', target='targets') | ||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param":1} | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | target_dict = {'targets': torch.zeros(4, 3)} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | self.assertDictEqual(metric.get_metric(), {'acc': 1}) | ||||
except Exception as e: | except Exception as e: | ||||
print(e) | print(e) | ||||
@@ -112,7 +111,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
# (9) check map, include unused | # (9) check map, include unused | ||||
try: | try: | ||||
metric = AccuracyMetric(pred='prediction', target='targets') | metric = AccuracyMetric(pred='prediction', target='targets') | ||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | target_dict = {'targets': torch.zeros(4, 3)} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | self.assertDictEqual(metric.get_metric(), {'acc': 1}) | ||||