|
|
@@ -13,6 +13,18 @@ class TestLoss(unittest.TestCase): |
|
|
|
b = torch.empty(3, dtype=torch.long).random_(5) |
|
|
|
ans = ce({"my_predict": a}, {"my_truth": b}) |
|
|
|
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) |
|
|
|
|
|
|
|
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=1) |
|
|
|
a = torch.randn(3, 4, 3) |
|
|
|
b = torch.randint(3, (3, 3)) |
|
|
|
ans = ce({"my_predict": a}, {"my_truth": b}) |
|
|
|
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) |
|
|
|
|
|
|
|
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2) |
|
|
|
a = torch.randn(3, 4, 3) |
|
|
|
b = torch.randint(3, (3, 4)) |
|
|
|
ans = ce({"my_predict": a}, {"my_truth": b}) |
|
|
|
self.assertEqual(ans, torch.nn.functional.cross_entropy(a.transpose(1, 2), b)) |
|
|
|
|
|
|
|
def test_BCELoss(self): |
|
|
|
bce = loss.BCELoss(pred="my_predict", target="my_truth") |
|
|
|