|
|
@@ -300,3 +300,13 @@ class TestLoss_v2(unittest.TestCase): |
|
|
|
b = torch.tensor([1, 0, 4]) |
|
|
|
ans = l1({"my_predict": a}, {"my_truth": b}) |
|
|
|
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) |
|
|
|
|
|
|
|
def test_check_error(self): |
|
|
|
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") |
|
|
|
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) |
|
|
|
b = torch.tensor([1, 0, 4]) |
|
|
|
with self.assertRaises(Exception): |
|
|
|
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) |
|
|
|
|
|
|
|
with self.assertRaises(Exception): |
|
|
|
ans = l1({"my_predict": a}, {"truth": b, "my": a}) |