@@ -67,8 +67,8 @@ class DataSet(object): | |||
self.dataset = dataset | |||
self.idx = idx | |||
def __getitem__(self, item): | |||
assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx) | |||
assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx]) | |||
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | |||
return self.dataset.field_arrays[item][self.idx] | |||
def __repr__(self): | |||
return self.dataset[self.idx].__repr__() | |||
@@ -7,11 +7,11 @@ class FieldArray(object): | |||
""" | |||
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): | |||
def __init__(self, name, content, padding_val=0, is_target=None, is_input=None): | |||
""" | |||
:param str name: the name of the FieldArray | |||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one. | |||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | |||
:param int padding_val: the integer for padding. Default: 0. | |||
:param bool is_target: If True, this FieldArray is used to compute loss. | |||
:param bool is_input: If True, this FieldArray is used to the model input. | |||
@@ -20,18 +20,44 @@ class FieldArray(object): | |||
if isinstance(content, list): | |||
content = content | |||
elif isinstance(content, np.ndarray): | |||
content = content.tolist() | |||
content = content.tolist() # convert np.ndarray into 2-D list | |||
else: | |||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | |||
self.content = content | |||
self.padding_val = padding_val | |||
self.is_target = is_target | |||
self.is_input = is_input | |||
self._is_target = None | |||
self._is_input = None | |||
self.BASIC_TYPES = (int, float, str, np.ndarray) | |||
self.is_2d_list = False | |||
self.pytype = self._type_detection(content) | |||
self.pytype = None # int, float, str, or np.ndarray | |||
self.dtype = None # np.int64, np.float64, np.str | |||
if is_input is not None: | |||
self.is_input = is_input | |||
if is_target is not None: | |||
self.is_target = is_target | |||
@property | |||
def is_input(self): | |||
return self._is_input | |||
@is_input.setter | |||
def is_input(self, value): | |||
self.pytype = self._type_detection(self.content) | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
self._is_input = value | |||
@property | |||
def is_target(self): | |||
return self._is_target | |||
@is_target.setter | |||
def is_target(self, value): | |||
self.pytype = self._type_detection(self.content) | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
self._is_target = value | |||
def _type_detection(self, content): | |||
""" | |||
@@ -42,9 +68,13 @@ class FieldArray(object): | |||
""" | |||
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | |||
# content is a 2-D list | |||
if not all(isinstance(_, list) for _ in content): # strict check 2-D list | |||
raise TypeError("Please provide 2-D list.") | |||
type_set = set([self._type_detection(x) for x in content]) | |||
if len(type_set) > 1: | |||
raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||
if len(type_set) == 2 and int in type_set and float in type_set: | |||
type_set = {float} | |||
elif len(type_set) > 1: | |||
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||
self.is_2d_list = True | |||
return type_set.pop() | |||
@@ -60,9 +90,9 @@ class FieldArray(object): | |||
# up-cast int to float | |||
return float | |||
else: | |||
raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set)) | |||
raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) | |||
else: | |||
raise RuntimeError("Cannot create FieldArray with type {}".format(type(content))) | |||
raise TypeError("Cannot create FieldArray with type {}".format(type(content))) | |||
@staticmethod | |||
def _map_to_np_type(basic_type): | |||
@@ -77,33 +107,38 @@ class FieldArray(object): | |||
:param val: int, float, str, or a list of one. | |||
""" | |||
val_type = type(val) | |||
if val_type == list: # shape check | |||
if self.is_2d_list is False: | |||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||
if len(val) == 0: | |||
raise RuntimeError("Cannot append an empty list.") | |||
val_list_type = set([type(_) for _ in val]) # type check | |||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||
# up-cast int to float | |||
val_type = float | |||
elif len(val_list_type) == 1: | |||
val_type = val_list_type.pop() | |||
if self.is_target is True or self.is_input is True: | |||
# only check type when used as target or input | |||
val_type = type(val) | |||
if val_type == list: # shape check | |||
if self.is_2d_list is False: | |||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||
if len(val) == 0: | |||
raise RuntimeError("Cannot append an empty list.") | |||
val_list_type = set([type(_) for _ in val]) # type check | |||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||
# up-cast int to float | |||
val_type = float | |||
elif len(val_list_type) == 1: | |||
val_type = val_list_type.pop() | |||
else: | |||
raise TypeError("Cannot append a list of {}".format(val_list_type)) | |||
else: | |||
raise RuntimeError("Cannot append a list of {}".format(val_list_type)) | |||
else: | |||
if self.is_2d_list is True: | |||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||
if val_type == float and self.pytype == int: | |||
# up-cast | |||
self.pytype = float | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
elif val_type == int and self.pytype == float: | |||
pass | |||
elif val_type == self.pytype: | |||
pass | |||
else: | |||
raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||
if self.is_2d_list is True: | |||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||
if val_type == float and self.pytype == int: | |||
# up-cast | |||
self.pytype = float | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
elif val_type == int and self.pytype == float: | |||
pass | |||
elif val_type == self.pytype: | |||
pass | |||
else: | |||
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||
self.content.append(val) | |||
def __getitem__(self, indices): | |||
@@ -121,7 +156,8 @@ class FieldArray(object): | |||
""" | |||
if isinstance(indices, int): | |||
return self.content[indices] | |||
assert self.is_input is True or self.is_target is True | |||
if self.is_input is False and self.is_target is False: | |||
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | |||
batch_size = len(indices) | |||
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | |||
if not is_iterable(self.content[0]): | |||
@@ -118,7 +118,7 @@ class LossBase(object): | |||
if not self._checked: | |||
for keys, val in pred_dict.items(): | |||
if keys in target_dict.keys(): | |||
duplicated.append(keys) | |||
duplicated.append(param_map[keys]) | |||
param_val_dict = {} | |||
for keys, val in pred_dict.items(): | |||
@@ -126,11 +126,10 @@ class LossBase(object): | |||
for keys, val in target_dict.items(): | |||
param_val_dict.update({keys: val}) | |||
# TODO: use the origin key to raise error | |||
if not self._checked: | |||
for keys in args: | |||
if param_map[keys] not in param_val_dict.keys(): | |||
missing.append(keys) | |||
missing.append(param_map[keys]) | |||
if len(duplicated) > 0 or len(missing) > 0: | |||
raise CheckError( | |||
@@ -33,7 +33,6 @@ class CNNText(torch.nn.Module): | |||
padding=padding) | |||
self.dropout = nn.Dropout(dropout) | |||
self.fc = encoder.Linear(sum(kernel_nums), num_classes) | |||
self._loss = nn.CrossEntropyLoss() | |||
def forward(self, word_seq): | |||
""" | |||
@@ -56,25 +55,3 @@ class CNNText(torch.nn.Module): | |||
output = self(word_seq) | |||
_, predict = output['output'].max(dim=1) | |||
return {'predict': predict} | |||
def get_loss(self, output, label_seq): | |||
""" | |||
:param output: output of forward(), [batch_size, seq_len] | |||
:param label_seq: true label in DataSet, [batch_size, seq_len] | |||
:return loss: torch.Tensor | |||
""" | |||
return self._loss(output, label_seq) | |||
def evaluate(self, predict, label_seq): | |||
""" | |||
:param predict: iterable predict tensors | |||
:param label_seq: iterable true label tensors | |||
:return accuracy: dict of float | |||
""" | |||
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0) | |||
predict, label_seq = predict.squeeze(), label_seq.squeeze() | |||
correct = (predict == label_seq).long().sum().item() | |||
total = label_seq.size(0) | |||
return {'acc': 1.0 * correct / total} |
@@ -44,11 +44,34 @@ class TestFieldArray(unittest.TestCase): | |||
def test_support_np_array(self): | |||
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False) | |||
self.assertEqual(fa.dtype, np.ndarray) | |||
self.assertEqual(fa.pytype, np.ndarray) | |||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | |||
self.assertEqual(fa.dtype, np.ndarray) | |||
self.assertEqual(fa.pytype, np.ndarray) | |||
fa = FieldArray("my_field", np.random.rand(3, 5), is_input=False) | |||
# in this case, pytype is actually a float. We do not care about it. | |||
self.assertEqual(fa.dtype, np.float64) | |||
def test_nested_list(self): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False) | |||
self.assertEqual(fa.pytype, float) | |||
self.assertEqual(fa.dtype, np.float64) | |||
def test_getitem_v1(self): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
ans = fa[[0, 1]] | |||
self.assertTrue(isinstance(ans, np.ndarray)) | |||
self.assertTrue(isinstance(ans[0], np.ndarray)) | |||
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) | |||
self.assertEqual(ans.dtype, np.float64) | |||
def test_getitem_v2(self): | |||
x = np.random.rand(10, 5) | |||
fa = FieldArray("my_field", x, is_input=True) | |||
indices = [0, 1, 3, 4, 6] | |||
for a, b in zip(fa[indices], x[indices]): | |||
self.assertListEqual(a.tolist(), b.tolist()) |
@@ -319,3 +319,12 @@ class TestLosserError(unittest.TestCase): | |||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||
def test_check_error(self): | |||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||
b = torch.tensor([1, 0, 4]) | |||
with self.assertRaises(Exception): | |||
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | |||
with self.assertRaises(Exception): | |||
ans = l1({"my_predict": a}, {"truth": b, "my": a}) |
@@ -1,9 +1,10 @@ | |||
import unittest | |||
from fastNLP.core.metrics import AccuracyMetric | |||
import torch | |||
import numpy as np | |||
import torch | |||
from fastNLP.core.metrics import AccuracyMetric | |||
class TestAccuracyMetric(unittest.TestCase): | |||
def test_AccuracyMetric1(self): | |||
@@ -12,9 +13,9 @@ class TestAccuracyMetric(unittest.TestCase): | |||
target_dict = {'target': torch.zeros(4)} | |||
metric = AccuracyMetric() | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
print(metric.get_metric()) | |||
# | |||
def test_AccuracyMetric2(self): | |||
# (2) with corrupted size | |||
try: | |||
@@ -22,13 +23,13 @@ class TestAccuracyMetric(unittest.TestCase): | |||
target_dict = {'target': torch.zeros(4)} | |||
metric = AccuracyMetric() | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
print(metric.get_metric()) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
# | |||
def test_AccuracyMetric3(self): | |||
# (3) the second batch is corrupted size | |||
try: | |||
@@ -47,7 +48,6 @@ class TestAccuracyMetric(unittest.TestCase): | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
# | |||
def test_AccuaryMetric4(self): | |||
# (5) check reset | |||
metric = AccuracyMetric() | |||
@@ -57,9 +57,9 @@ class TestAccuracyMetric(unittest.TestCase): | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)+1} | |||
target_dict = {'target': torch.zeros(4, 3) + 1} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc':0}) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 0}) | |||
def test_AccuaryMetric5(self): | |||
# (5) check reset | |||
@@ -70,11 +70,10 @@ class TestAccuracyMetric(unittest.TestCase): | |||
self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)+1} | |||
target_dict = {'target': torch.zeros(4, 3) + 1} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc':0.5}) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 0.5}) | |||
# | |||
def test_AccuaryMetric6(self): | |||
# (6) check numpy array is not acceptable | |||
try: | |||
@@ -99,9 +98,9 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# (8) check map, does not match. use stop_fast_param to stop fast param map | |||
try: | |||
metric = AccuracyMetric(pred='predictions', target='targets') | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param":1} | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
except Exception as e: | |||
print(e) | |||
@@ -112,7 +111,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# (9) check map, include unused | |||
try: | |||
metric = AccuracyMetric(pred='prediction', target='targets') | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||