From 23305af7335148669fefb74c217648d439f32182 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Thu, 13 Dec 2018 10:55:48 +0800 Subject: [PATCH] Fix failed tests. --- test/core/test_metrics.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index c6267664..125b9156 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -52,28 +52,24 @@ class TestAccuracyMetric(unittest.TestCase): def test_AccuaryMetric4(self): # (5) check reset metric = AccuracyMetric() - pred_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3)} - metric(pred_dict=pred_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc': 1}) - - pred_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3) + 1} + pred_dict = {"pred": torch.randn(4, 3, 2)} + target_dict = {'target': torch.ones(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc': 0}) + ans = torch.argmax(pred_dict["pred"], dim=2).to(target_dict["target"]) == target_dict["target"] + res = metric.get_metric() + self.assertTrue(isinstance(res, dict)) + self.assertTrue("acc" in res) + self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3) def test_AccuaryMetric5(self): # (5) check reset metric = AccuracyMetric() - pred_dict = {"pred": torch.zeros(4, 3, 2)} + pred_dict = {"pred": torch.randn(4, 3, 2)} target_dict = {'target': torch.zeros(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) - - pred_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3) + 1} - metric(pred_dict=pred_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc': 0.5}) + res = metric.get_metric(reset=False) + ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean() + self.assertAlmostEqual(res["acc"], float(ans), places=4) def test_AccuaryMetric6(self): # (6) check numpy array is not acceptable @@ -90,10 +86,12 @@ class TestAccuracyMetric(unittest.TestCase): def test_AccuaryMetric7(self): # (7) check map, match metric = AccuracyMetric(pred='predictions', target='targets') - pred_dict = {"predictions": torch.zeros(4, 3, 2)} + pred_dict = {"predictions": torch.randn(4, 3, 2)} target_dict = {'targets': torch.zeros(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc': 1}) + res = metric.get_metric() + ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean() + self.assertAlmostEqual(res["acc"], float(ans), places=4) def test_AccuaryMetric8(self): # (8) check map, does not match. use stop_fast_param to stop fast param map