Browse Source

解决CrossEntropyLoss测试因为数值问题无法通过测试的问题

tags/v0.5.5
yh 5 years ago
parent
commit
f887da12a1
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      test/core/test_loss.py

+ 2
- 2
test/core/test_loss.py View File

@@ -18,13 +18,13 @@ class TestLoss(unittest.TestCase):
a = torch.randn(3, 4, 3) a = torch.randn(3, 4, 3)
b = torch.randint(3, (3, 3)) b = torch.randint(3, (3, 3))
ans = ce({"my_predict": a}, {"my_truth": b}) ans = ce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))
self.assertAlmostEqual(ans.item(), torch.nn.functional.cross_entropy(a, b).item(), places=4)


ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2) ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2)
a = torch.randn(3, 4, 3) a = torch.randn(3, 4, 3)
b = torch.randint(3, (3, 4)) b = torch.randint(3, (3, 4))
ans = ce({"my_predict": a}, {"my_truth": b}) ans = ce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.cross_entropy(a.transpose(1, 2), b))
self.assertAlmostEqual(ans.item(), torch.nn.functional.cross_entropy(a.transpose(1, 2), b).item(), places=4)
def test_BCELoss(self): def test_BCELoss(self):
bce = loss.BCELoss(pred="my_predict", target="my_truth") bce = loss.BCELoss(pred="my_predict", target="my_truth")


Loading…
Cancel
Save