|
|
@@ -52,28 +52,24 @@ class TestAccuracyMetric(unittest.TestCase): |
|
|
|
def test_AccuaryMetric4(self): |
|
|
|
# (5) check reset |
|
|
|
metric = AccuracyMetric() |
|
|
|
pred_dict = {"pred": torch.zeros(4, 3, 2)} |
|
|
|
target_dict = {'target': torch.zeros(4, 3)} |
|
|
|
metric(pred_dict=pred_dict, target_dict=target_dict) |
|
|
|
self.assertDictEqual(metric.get_metric(), {'acc': 1}) |
|
|
|
|
|
|
|
pred_dict = {"pred": torch.zeros(4, 3, 2)} |
|
|
|
target_dict = {'target': torch.zeros(4, 3) + 1} |
|
|
|
pred_dict = {"pred": torch.randn(4, 3, 2)} |
|
|
|
target_dict = {'target': torch.ones(4, 3)} |
|
|
|
metric(pred_dict=pred_dict, target_dict=target_dict) |
|
|
|
self.assertDictEqual(metric.get_metric(), {'acc': 0}) |
|
|
|
ans = torch.argmax(pred_dict["pred"], dim=2).to(target_dict["target"]) == target_dict["target"] |
|
|
|
res = metric.get_metric() |
|
|
|
self.assertTrue(isinstance(res, dict)) |
|
|
|
self.assertTrue("acc" in res) |
|
|
|
self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3) |
|
|
|
|
|
|
|
def test_AccuaryMetric5(self): |
|
|
|
# (5) check reset |
|
|
|
metric = AccuracyMetric() |
|
|
|
pred_dict = {"pred": torch.zeros(4, 3, 2)} |
|
|
|
pred_dict = {"pred": torch.randn(4, 3, 2)} |
|
|
|
target_dict = {'target': torch.zeros(4, 3)} |
|
|
|
metric(pred_dict=pred_dict, target_dict=target_dict) |
|
|
|
self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) |
|
|
|
|
|
|
|
pred_dict = {"pred": torch.zeros(4, 3, 2)} |
|
|
|
target_dict = {'target': torch.zeros(4, 3) + 1} |
|
|
|
metric(pred_dict=pred_dict, target_dict=target_dict) |
|
|
|
self.assertDictEqual(metric.get_metric(), {'acc': 0.5}) |
|
|
|
res = metric.get_metric(reset=False) |
|
|
|
ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean() |
|
|
|
self.assertAlmostEqual(res["acc"], float(ans), places=4) |
|
|
|
|
|
|
|
def test_AccuaryMetric6(self): |
|
|
|
# (6) check numpy array is not acceptable |
|
|
@@ -90,10 +86,12 @@ class TestAccuracyMetric(unittest.TestCase): |
|
|
|
def test_AccuaryMetric7(self): |
|
|
|
# (7) check map, match |
|
|
|
metric = AccuracyMetric(pred='predictions', target='targets') |
|
|
|
pred_dict = {"predictions": torch.zeros(4, 3, 2)} |
|
|
|
pred_dict = {"predictions": torch.randn(4, 3, 2)} |
|
|
|
target_dict = {'targets': torch.zeros(4, 3)} |
|
|
|
metric(pred_dict=pred_dict, target_dict=target_dict) |
|
|
|
self.assertDictEqual(metric.get_metric(), {'acc': 1}) |
|
|
|
res = metric.get_metric() |
|
|
|
ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean() |
|
|
|
self.assertAlmostEqual(res["acc"], float(ans), places=4) |
|
|
|
|
|
|
|
def test_AccuaryMetric8(self): |
|
|
|
# (8) check map, does not match. use stop_fast_param to stop fast param map |
|
|
|