Browse Source

Update test_loss

tags/v0.2.0
FFTYYY 6 years ago
parent
commit
07fb61efdc
1 changed files with 18 additions and 3 deletions
  1. +18
    -3
      test/core/test_loss.py

+ 18
- 3
test/core/test_loss.py View File

@@ -5,7 +5,6 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling
@@ -51,6 +50,8 @@ class TestLoss(unittest.TestCase):
print ("loss = %f" % (los))
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_2(self):
#验证squash()的正确性
print ("----------------------------------")
@@ -82,12 +83,14 @@ class TestLoss(unittest.TestCase):

y = tc.log(y)
los = loss_func(y , gy)
print ("loss = %f" % (los))

r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
r /= 6
print ("loss = %f" % (los))
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_3(self):
#验证pack_padded_sequence()的正确性
print ("----------------------------------")
@@ -130,6 +133,8 @@ class TestLoss(unittest.TestCase):
r /= 6
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_4(self):
#验证unpad()的正确性
print ("----------------------------------")
@@ -169,6 +174,9 @@ class TestLoss(unittest.TestCase):
r /= 7
print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_5(self):
#验证mask()和make_mask()的正确性
print ("----------------------------------")
@@ -217,6 +225,10 @@ class TestLoss(unittest.TestCase):
r /= 8
print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))
self.assertEqual(int(los2 * 1000), int(r * 1000))

def test_case_6(self):
#验证unpad_mask()的正确性
print ("----------------------------------")
@@ -256,6 +268,8 @@ class TestLoss(unittest.TestCase):
r /= 7
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_7(self):
#验证一些其他东西
print ("----------------------------------")
@@ -295,6 +309,7 @@ class TestLoss(unittest.TestCase):
r = - log(.3) - log(.5) - log(.3)
r /= 3
print ("r = %f" % (r))
self.assertEqual(int(los * 1000), int(r * 1000))

if __name__ == "__main__":
unittest.main()
unittest.main()

Loading…
Cancel
Save