From 07fb61efdc5940e9c9b7162c4c05c667848120d5 Mon Sep 17 00:00:00 2001 From: FFTYYY <1004473299@qq.com> Date: Sat, 10 Nov 2018 23:21:26 +0800 Subject: [PATCH] Update test_loss --- test/core/test_loss.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/test/core/test_loss.py b/test/core/test_loss.py index d6b43fc1..d7cafc13 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -5,7 +5,6 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.metrics import SeqLabelEvaluator from fastNLP.core.field import TextField, LabelField from fastNLP.core.instance import Instance - from fastNLP.core.optimizer import Optimizer from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.models.sequence_modeling import SeqLabeling @@ -51,6 +50,8 @@ class TestLoss(unittest.TestCase): print ("loss = %f" % (los)) print ("r = %f" % (r)) + self.assertEqual(int(los * 1000), int(r * 1000)) + def test_case_2(self): #验证squash()的正确性 print ("----------------------------------") @@ -82,12 +83,14 @@ class TestLoss(unittest.TestCase): y = tc.log(y) los = loss_func(y , gy) + print ("loss = %f" % (los)) r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) r /= 6 - print ("loss = %f" % (los)) print ("r = %f" % (r)) + self.assertEqual(int(los * 1000), int(r * 1000)) + def test_case_3(self): #验证pack_padded_sequence()的正确性 print ("----------------------------------") @@ -130,6 +133,8 @@ class TestLoss(unittest.TestCase): r /= 6 print ("r = %f" % (r)) + self.assertEqual(int(los * 1000), int(r * 1000)) + def test_case_4(self): #验证unpad()的正确性 print ("----------------------------------") @@ -169,6 +174,9 @@ class TestLoss(unittest.TestCase): r /= 7 print ("r = %f" % (r)) + + self.assertEqual(int(los * 1000), int(r * 1000)) + def test_case_5(self): #验证mask()和make_mask()的正确性 print ("----------------------------------") @@ -217,6 +225,10 @@ class TestLoss(unittest.TestCase): r /= 8 print ("r = %f" % (r)) + + self.assertEqual(int(los * 1000), int(r * 1000)) + self.assertEqual(int(los2 * 1000), int(r * 1000)) + def test_case_6(self): #验证unpad_mask()的正确性 print ("----------------------------------") @@ -256,6 +268,8 @@ class TestLoss(unittest.TestCase): r /= 7 print ("r = %f" % (r)) + self.assertEqual(int(los * 1000), int(r * 1000)) + def test_case_7(self): #验证一些其他东西 print ("----------------------------------") @@ -295,6 +309,7 @@ class TestLoss(unittest.TestCase): r = - log(.3) - log(.5) - log(.3) r /= 3 print ("r = %f" % (r)) + self.assertEqual(int(los * 1000), int(r * 1000)) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()