| @@ -5,7 +5,6 @@ from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||
| from fastNLP.core.field import TextField, LabelField | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.optimizer import Optimizer | |||
| from fastNLP.core.trainer import SeqLabelTrainer | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| @@ -51,6 +50,8 @@ class TestLoss(unittest.TestCase): | |||
| print ("loss = %f" % (los)) | |||
| print ("r = %f" % (r)) | |||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||
| def test_case_2(self): | |||
| #验证squash()的正确性 | |||
| print ("----------------------------------") | |||
| @@ -82,12 +83,14 @@ class TestLoss(unittest.TestCase): | |||
| y = tc.log(y) | |||
| los = loss_func(y , gy) | |||
| print ("loss = %f" % (los)) | |||
| r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) | |||
| r /= 6 | |||
| print ("loss = %f" % (los)) | |||
| print ("r = %f" % (r)) | |||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||
| def test_case_3(self): | |||
| #验证pack_padded_sequence()的正确性 | |||
| print ("----------------------------------") | |||
| @@ -130,6 +133,8 @@ class TestLoss(unittest.TestCase): | |||
| r /= 6 | |||
| print ("r = %f" % (r)) | |||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||
| def test_case_4(self): | |||
| #验证unpad()的正确性 | |||
| print ("----------------------------------") | |||
| @@ -169,6 +174,9 @@ class TestLoss(unittest.TestCase): | |||
| r /= 7 | |||
| print ("r = %f" % (r)) | |||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||
| def test_case_5(self): | |||
| #验证mask()和make_mask()的正确性 | |||
| print ("----------------------------------") | |||
| @@ -217,6 +225,10 @@ class TestLoss(unittest.TestCase): | |||
| r /= 8 | |||
| print ("r = %f" % (r)) | |||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||
| self.assertEqual(int(los2 * 1000), int(r * 1000)) | |||
| def test_case_6(self): | |||
| #验证unpad_mask()的正确性 | |||
| print ("----------------------------------") | |||
| @@ -256,6 +268,8 @@ class TestLoss(unittest.TestCase): | |||
| r /= 7 | |||
| print ("r = %f" % (r)) | |||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||
| def test_case_7(self): | |||
| #验证一些其他东西 | |||
| print ("----------------------------------") | |||
| @@ -295,6 +309,7 @@ class TestLoss(unittest.TestCase): | |||
| r = - log(.3) - log(.5) - log(.3) | |||
| r /= 3 | |||
| print ("r = %f" % (r)) | |||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| unittest.main() | |||