| @@ -13,6 +13,18 @@ class TestLoss(unittest.TestCase): | |||||
| b = torch.empty(3, dtype=torch.long).random_(5) | b = torch.empty(3, dtype=torch.long).random_(5) | ||||
| ans = ce({"my_predict": a}, {"my_truth": b}) | ans = ce({"my_predict": a}, {"my_truth": b}) | ||||
| self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | ||||
| ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=1) | |||||
| a = torch.randn(3, 4, 3) | |||||
| b = torch.randint(3, (3, 3)) | |||||
| ans = ce({"my_predict": a}, {"my_truth": b}) | |||||
| self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | |||||
| ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2) | |||||
| a = torch.randn(3, 4, 3) | |||||
| b = torch.randint(3, (3, 4)) | |||||
| ans = ce({"my_predict": a}, {"my_truth": b}) | |||||
| self.assertEqual(ans, torch.nn.functional.cross_entropy(a.transpose(1, 2), b)) | |||||
| def test_BCELoss(self): | def test_BCELoss(self): | ||||
| bce = loss.BCELoss(pred="my_predict", target="my_truth") | bce = loss.BCELoss(pred="my_predict", target="my_truth") | ||||