@@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result):
return allen_result
class TestConfusionMatrixMetric(unittest.TestCase):
def test_ConfusionMatrixMetric1(self):
pred_dict = {"pred": torch.zeros(4,3)}
@@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
def test_ConfusionMatrixMetric2(self):
# (2) with corrupted size
with self.assertRaises(Exception):
pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4)}
@@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase):
print(metric.get_metric())
def test_ConfusionMatrixMetric4(self):
# (4) check reset
metric = ConfusionMatrixMetric()
@@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
def test_ConfusionMatrixMetric5(self):
# (5) check numpy array is not acceptable
with self.assertRaises(Exception):
metric = ConfusionMatrixMetric()
pred_dict = {"pred": np.zeros((4, 3, 2))}
@@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())
def test_duplicate(self):
# 0.4.1的潜在bug,不能出现形参重复的情况
metric = ConfusionMatrixMetric(pred='predictions', target='targets')
@@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())
def test_seq_len(self):
N = 256
seq_len = torch.zeros(N).long()
@@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
print(metric.get_metric())
class TestAccuracyMetric(unittest.TestCase):
def test_AccuracyMetric1(self):
# (1) only input, targets passed