| @@ -5,7 +5,6 @@ from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
| from fastNLP.core.field import TextField, LabelField | from fastNLP.core.field import TextField, LabelField | ||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| @@ -51,6 +50,8 @@ class TestLoss(unittest.TestCase): | |||||
| print ("loss = %f" % (los)) | print ("loss = %f" % (los)) | ||||
| print ("r = %f" % (r)) | print ("r = %f" % (r)) | ||||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
| def test_case_2(self): | def test_case_2(self): | ||||
| #验证squash()的正确性 | #验证squash()的正确性 | ||||
| print ("----------------------------------") | print ("----------------------------------") | ||||
| @@ -82,12 +83,14 @@ class TestLoss(unittest.TestCase): | |||||
| y = tc.log(y) | y = tc.log(y) | ||||
| los = loss_func(y , gy) | los = loss_func(y , gy) | ||||
| print ("loss = %f" % (los)) | |||||
| r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) | r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) | ||||
| r /= 6 | r /= 6 | ||||
| print ("loss = %f" % (los)) | |||||
| print ("r = %f" % (r)) | print ("r = %f" % (r)) | ||||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
| def test_case_3(self): | def test_case_3(self): | ||||
| #验证pack_padded_sequence()的正确性 | #验证pack_padded_sequence()的正确性 | ||||
| print ("----------------------------------") | print ("----------------------------------") | ||||
| @@ -130,6 +133,8 @@ class TestLoss(unittest.TestCase): | |||||
| r /= 6 | r /= 6 | ||||
| print ("r = %f" % (r)) | print ("r = %f" % (r)) | ||||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
| def test_case_4(self): | def test_case_4(self): | ||||
| #验证unpad()的正确性 | #验证unpad()的正确性 | ||||
| print ("----------------------------------") | print ("----------------------------------") | ||||
| @@ -169,6 +174,9 @@ class TestLoss(unittest.TestCase): | |||||
| r /= 7 | r /= 7 | ||||
| print ("r = %f" % (r)) | print ("r = %f" % (r)) | ||||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
| def test_case_5(self): | def test_case_5(self): | ||||
| #验证mask()和make_mask()的正确性 | #验证mask()和make_mask()的正确性 | ||||
| print ("----------------------------------") | print ("----------------------------------") | ||||
| @@ -217,6 +225,10 @@ class TestLoss(unittest.TestCase): | |||||
| r /= 8 | r /= 8 | ||||
| print ("r = %f" % (r)) | print ("r = %f" % (r)) | ||||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
| self.assertEqual(int(los2 * 1000), int(r * 1000)) | |||||
| def test_case_6(self): | def test_case_6(self): | ||||
| #验证unpad_mask()的正确性 | #验证unpad_mask()的正确性 | ||||
| print ("----------------------------------") | print ("----------------------------------") | ||||
| @@ -256,6 +268,8 @@ class TestLoss(unittest.TestCase): | |||||
| r /= 7 | r /= 7 | ||||
| print ("r = %f" % (r)) | print ("r = %f" % (r)) | ||||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
| def test_case_7(self): | def test_case_7(self): | ||||
| #验证一些其他东西 | #验证一些其他东西 | ||||
| print ("----------------------------------") | print ("----------------------------------") | ||||
| @@ -295,6 +309,7 @@ class TestLoss(unittest.TestCase): | |||||
| r = - log(.3) - log(.5) - log(.3) | r = - log(.3) - log(.5) - log(.3) | ||||
| r /= 3 | r /= 3 | ||||
| print ("r = %f" % (r)) | print ("r = %f" % (r)) | ||||
| self.assertEqual(int(los * 1000), int(r * 1000)) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| unittest.main() | |||||
| unittest.main() | |||||