|
|
@@ -300,3 +300,22 @@ class TestLoss_v2(unittest.TestCase): |
|
|
|
b = torch.tensor([1, 0, 4]) |
|
|
|
ans = l1({"my_predict": a}, {"my_truth": b}) |
|
|
|
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) |
|
|
|
|
|
|
|
class TestLosserError(unittest.TestCase): |
|
|
|
def test_losser1(self): |
|
|
|
# (1) only input, targets passed |
|
|
|
pred_dict = {"pred": torch.zeros(4, 3)} |
|
|
|
target_dict = {'target': torch.zeros(4).long()} |
|
|
|
los = loss.CrossEntropyLoss() |
|
|
|
|
|
|
|
print(los(pred_dict=pred_dict, target_dict=target_dict)) |
|
|
|
|
|
|
|
# |
|
|
|
def test_AccuracyMetric2(self): |
|
|
|
# (2) with corrupted size |
|
|
|
pred_dict = {"pred": torch.zeros(16, 3, 4)} |
|
|
|
target_dict = {'target': torch.zeros(16, 3).long()} |
|
|
|
los = loss.CrossEntropyLoss() |
|
|
|
|
|
|
|
print(los(pred_dict=pred_dict, target_dict=target_dict)) |
|
|
|
|