Browse Source

test loss

tags/v0.2.0^2
yh 6 years ago
parent
commit
62c63f159a
1 changed files with 19 additions and 0 deletions
  1. +19
    -0
      test/core/test_loss.py

+ 19
- 0
test/core/test_loss.py View File

@@ -300,3 +300,22 @@ class TestLoss_v2(unittest.TestCase):
b = torch.tensor([1, 0, 4])
ans = l1({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))

class TestLosserError(unittest.TestCase):
def test_losser1(self):
# (1) only input, targets passed
pred_dict = {"pred": torch.zeros(4, 3)}
target_dict = {'target': torch.zeros(4).long()}
los = loss.CrossEntropyLoss()

print(los(pred_dict=pred_dict, target_dict=target_dict))

#
def test_AccuracyMetric2(self):
# (2) with corrupted size
pred_dict = {"pred": torch.zeros(16, 3, 4)}
target_dict = {'target': torch.zeros(16, 3).long()}
los = loss.CrossEntropyLoss()

print(los(pred_dict=pred_dict, target_dict=target_dict))


Loading…
Cancel
Save