From 9c0190fbd82f4c50956d41e32a6370eb93317add Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 15 Oct 2019 19:28:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=AF=B9CrossEntropyLoss?= =?UTF-8?q?=E4=B8=ADclass=5Fin=5Fdim=E7=9A=84=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/core/test_loss.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/core/test_loss.py b/test/core/test_loss.py index 9ba8159f..a57e6542 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -13,6 +13,18 @@ class TestLoss(unittest.TestCase): b = torch.empty(3, dtype=torch.long).random_(5) ans = ce({"my_predict": a}, {"my_truth": b}) self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) + + ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=1) + a = torch.randn(3, 4, 3) + b = torch.randint(3, (3, 3)) + ans = ce({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) + + ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2) + a = torch.randn(3, 4, 3) + b = torch.randint(3, (3, 4)) + ans = ce({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.cross_entropy(a.transpose(1, 2), b)) def test_BCELoss(self): bce = loss.BCELoss(pred="my_predict", target="my_truth")