Browse Source

新增对CrossEntropyLoss中class_in_dim的测试

tags/v0.5.5
yh 5 years ago
parent
commit
9c0190fbd8
1 changed files with 12 additions and 0 deletions
  1. +12
    -0
      test/core/test_loss.py

+ 12
- 0
test/core/test_loss.py View File

@@ -13,6 +13,18 @@ class TestLoss(unittest.TestCase):
b = torch.empty(3, dtype=torch.long).random_(5) b = torch.empty(3, dtype=torch.long).random_(5)
ans = ce({"my_predict": a}, {"my_truth": b}) ans = ce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))

ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=1)
a = torch.randn(3, 4, 3)
b = torch.randint(3, (3, 3))
ans = ce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))

ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2)
a = torch.randn(3, 4, 3)
b = torch.randint(3, (3, 4))
ans = ce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.cross_entropy(a.transpose(1, 2), b))
def test_BCELoss(self): def test_BCELoss(self):
bce = loss.BCELoss(pred="my_predict", target="my_truth") bce = loss.BCELoss(pred="my_predict", target="my_truth")


Loading…
Cancel
Save