diff --git a/test/core/test_loss.py b/test/core/test_loss.py index 9b77d0a1..060aefb3 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -300,3 +300,22 @@ class TestLoss_v2(unittest.TestCase): b = torch.tensor([1, 0, 4]) ans = l1({"my_predict": a}, {"my_truth": b}) self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) + +class TestLosserError(unittest.TestCase): + def test_losser1(self): + # (1) only input, targets passed + pred_dict = {"pred": torch.zeros(4, 3)} + target_dict = {'target': torch.zeros(4).long()} + los = loss.CrossEntropyLoss() + + print(los(pred_dict=pred_dict, target_dict=target_dict)) + + # + def test_AccuracyMetric2(self): + # (2) with corrupted size + pred_dict = {"pred": torch.zeros(16, 3, 4)} + target_dict = {'target': torch.zeros(16, 3).long()} + los = loss.CrossEntropyLoss() + + print(los(pred_dict=pred_dict, target_dict=target_dict)) +