Browse Source

Update test_loss

tags/v0.2.0
FFTYYY 6 years ago
parent
commit
07fb61efdc
1 changed files with 18 additions and 3 deletions
  1. +18
    -3
      test/core/test_loss.py

+ 18
- 3
test/core/test_loss.py View File

@@ -5,7 +5,6 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import SeqLabelEvaluator from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.field import TextField, LabelField from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance

from fastNLP.core.optimizer import Optimizer from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling from fastNLP.models.sequence_modeling import SeqLabeling
@@ -51,6 +50,8 @@ class TestLoss(unittest.TestCase):
print ("loss = %f" % (los)) print ("loss = %f" % (los))
print ("r = %f" % (r)) print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_2(self): def test_case_2(self):
#验证squash()的正确性 #验证squash()的正确性
print ("----------------------------------") print ("----------------------------------")
@@ -82,12 +83,14 @@ class TestLoss(unittest.TestCase):


y = tc.log(y) y = tc.log(y)
los = loss_func(y , gy) los = loss_func(y , gy)
print ("loss = %f" % (los))


r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
r /= 6 r /= 6
print ("loss = %f" % (los))
print ("r = %f" % (r)) print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_3(self): def test_case_3(self):
#验证pack_padded_sequence()的正确性 #验证pack_padded_sequence()的正确性
print ("----------------------------------") print ("----------------------------------")
@@ -130,6 +133,8 @@ class TestLoss(unittest.TestCase):
r /= 6 r /= 6
print ("r = %f" % (r)) print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_4(self): def test_case_4(self):
#验证unpad()的正确性 #验证unpad()的正确性
print ("----------------------------------") print ("----------------------------------")
@@ -169,6 +174,9 @@ class TestLoss(unittest.TestCase):
r /= 7 r /= 7
print ("r = %f" % (r)) print ("r = %f" % (r))



self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_5(self): def test_case_5(self):
#验证mask()和make_mask()的正确性 #验证mask()和make_mask()的正确性
print ("----------------------------------") print ("----------------------------------")
@@ -217,6 +225,10 @@ class TestLoss(unittest.TestCase):
r /= 8 r /= 8
print ("r = %f" % (r)) print ("r = %f" % (r))



self.assertEqual(int(los * 1000), int(r * 1000))
self.assertEqual(int(los2 * 1000), int(r * 1000))

def test_case_6(self): def test_case_6(self):
#验证unpad_mask()的正确性 #验证unpad_mask()的正确性
print ("----------------------------------") print ("----------------------------------")
@@ -256,6 +268,8 @@ class TestLoss(unittest.TestCase):
r /= 7 r /= 7
print ("r = %f" % (r)) print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_7(self): def test_case_7(self):
#验证一些其他东西 #验证一些其他东西
print ("----------------------------------") print ("----------------------------------")
@@ -295,6 +309,7 @@ class TestLoss(unittest.TestCase):
r = - log(.3) - log(.5) - log(.3) r = - log(.3) - log(.5) - log(.3)
r /= 3 r /= 3
print ("r = %f" % (r)) print ("r = %f" % (r))
self.assertEqual(int(los * 1000), int(r * 1000))


if __name__ == "__main__": if __name__ == "__main__":
unittest.main()
unittest.main()

Loading…
Cancel
Save