@@ -6,67 +6,123 @@ import torch

import numpy as np


class TestAccuracyMetric(unittest.TestCase):
    def test_AccuracyMetric1(self):
        # (1) only input, targets passed
        output_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4)}
        metric = AccuracyMetric()
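        # NOTE: pred holds (batch=4, num_classes=3) scores and target holds 4 labels;
        # assuming AccuracyMetric argmaxes over the class dimension, the all-zero
        # scores predict class 0 everywhere and match the all-zero targets.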
        metric(output_dict=output_dict, target_dict=target_dict)
        print(metric.get_metric())

    # def test_AccuracyMetric1(self):
    #     # (1) only input, targets passed
    #     pred_dict = {"pred": torch.zeros(4, 3)}
    #     target_dict = {'target': torch.zeros(4)}
    #     metric = AccuracyMetric()
    #
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     print(metric.get_metric())
    #
    # def test_AccuracyMetric2(self):
    #     # (2) with corrupted size
    #     try:
    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #         target_dict = {'target': torch.zeros(4)}
    #         metric = AccuracyMetric()
    #
    #         metric(pred_dict=pred_dict, target_dict=target_dict)
    #         print(metric.get_metric())
    #     except Exception as e:
    #         print(e)
    #         return
    #     self.assertTrue(True, False), "No exception catches."
    #
    # def test_AccuracyMetric3(self):
    #     # (3) with check=False , the second batch is corrupted size
    #     try:
    #         metric = AccuracyMetric()
    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #         target_dict = {'target': torch.zeros(4, 3)}
    #         metric(pred_dict=pred_dict, target_dict=target_dict)
    #
    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #         target_dict = {'target': torch.zeros(4)}
    #         metric(pred_dict=pred_dict, target_dict=target_dict)
    #
    #         print(metric.get_metric())
    #     except Exception as e:
    #         print(e)
    #         return
    #     self.assertTrue(True, False), "No exception catches."
    #
    # def test_AccuracyMetric4(self):
    #     # (4) with check=True , the second batch is corrupted size
    #     try:
    #         metric = AccuracyMetric()
    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #         target_dict = {'target': torch.zeros(4, 3)}
    #         metric(pred_dict=pred_dict, target_dict=target_dict)
    #
    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #         target_dict = {'target': torch.zeros(4)}
    #         metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #
    #         print(metric.get_metric())
    #
    #     except Exception as e:
    #         print(e)
    #         return
    #     self.assertTrue(True, False), "No exception catches."
    #
    # def test_AccuaryMetric5(self):
    #     # (5) check reset
    #     metric = AccuracyMetric()
    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #     target_dict = {'target': torch.zeros(4, 3)}
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
    #
    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #     target_dict = {'target': torch.zeros(4, 3)+1}
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     self.assertDictEqual(metric.get_metric(), {'acc':0})
    #
    # def test_AccuaryMetric6(self):
    #     # (6) check numpy array is not acceptable
    #     try:
    #         metric = AccuracyMetric()
    #         pred_dict = {"pred": np.zeros((4, 3, 2))}
    #         target_dict = {'target': np.zeros((4, 3))}
    #         metric(pred_dict=pred_dict, target_dict=target_dict)
    #         self.assertDictEqual(metric.get_metric(), {'acc': 1})
    #     except Exception as e:
    #         print(e)
    #         return
    #     self.assertTrue(True, False), "No exception catches."
    #
    # def test_AccuaryMetric7(self):
    #     # (7) check map, match
    #     metric = AccuracyMetric(pred='predictions', target='targets')
    #     pred_dict = {"predictions": torch.zeros(4, 3, 2)}
    #     target_dict = {'targets': torch.zeros(4, 3)}
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
    #
    # def test_AccuaryMetric8(self):
    #     # (8) check map, does not match
    #     try:
    #         metric = AccuracyMetric(pred='predictions', target='targets')
    #         pred_dict = {"prediction": torch.zeros(4, 3, 2)}
    #         target_dict = {'targets': torch.zeros(4, 3)}
    #         metric(pred_dict=pred_dict, target_dict=target_dict)
    #         self.assertDictEqual(metric.get_metric(), {'acc': 1})
    #     except Exception as e:
    #         print(e)
    #         return
    #     self.assertTrue(True, False), "No exception catches."

    def test_AccuracyMetric2(self):
        # (2) with corrupted size
        output_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4)}
        metric = AccuracyMetric()

        metric(output_dict=output_dict, target_dict=target_dict)
        print(metric.get_metric())

    def test_AccuracyMetric9(self):
        # (9) check map, include unused
        try:
            metric = AccuracyMetric(pred='predictions', target='targets')
            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
            target_dict = {'targets': torch.zeros(4, 3)}
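            # NOTE: the mapped key 'predictions' is absent from pred_dict (it holds
            # 'prediction' plus an extra 'unused' entry), so this call is expected
            # to raise instead of reaching the assertion below.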
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")

    def test_AccuracyMetric3(self):
        # (3) with check=False, the second batch is corrupted size
        metric = AccuracyMetric()
        output_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(output_dict=output_dict, target_dict=target_dict)

        output_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4)}
        metric(output_dict=output_dict, target_dict=target_dict)

        print(metric.get_metric())
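
    # NOTE: test_AccuracyMetric4 below repeats the scenario from test_AccuracyMetric3;
    # the only difference is that the second, shape-mismatched call passes check=True.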

    def test_AccuracyMetric4(self):
        # (4) with check=True, the second batch is corrupted size
        metric = AccuracyMetric()
        output_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(output_dict=output_dict, target_dict=target_dict)

        output_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4)}
        metric(output_dict=output_dict, target_dict=target_dict, check=True)

        print(metric.get_metric())

    def test_AccuracyMetric5(self):
        # (5) check reset
        metric = AccuracyMetric()
        output_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(output_dict=output_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 1})

        output_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3) + 1}
        metric(output_dict=output_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 0})
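        # NOTE: assuming the (4, 3, 2) scores are argmaxed over the last dimension,
        # each call above scores 4 x 3 = 12 predictions; if get_metric() did not
        # reset the accumulated counts, the second result would be 0.5 (12 of 24
        # correct) rather than the 0 asserted here.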

    def test_AccuracyMetric6(self):
        # (6) check numpy array is not acceptable
        metric = AccuracyMetric()
        output_dict = {"pred": np.zeros((4, 3, 2))}
        target_dict = {'target': np.zeros((4, 3))}
        metric(output_dict=output_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 1})