diff --git a/test/core/test_loss.py b/test/core/test_loss.py index 9ba8159f..a57e6542 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -13,6 +13,18 @@ class TestLoss(unittest.TestCase): b = torch.empty(3, dtype=torch.long).random_(5) ans = ce({"my_predict": a}, {"my_truth": b}) self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) + + ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=1) + a = torch.randn(3, 4, 3) + b = torch.randint(3, (3, 3)) + ans = ce({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) + + ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2) + a = torch.randn(3, 4, 3) + b = torch.randint(3, (3, 4)) + ans = ce({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.cross_entropy(a.transpose(1, 2), b)) def test_BCELoss(self): bce = loss.BCELoss(pred="my_predict", target="my_truth")