You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_loss.py 7.3 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. import math
  2. import unittest
  3. import torch as tc
  4. import torch.nn.functional as F
  5. import fastNLP.core.losses as loss
  6. class TestLoss(unittest.TestCase):
  7. def test_case_1(self):
  8. #验证nllloss的原理
  9. print (".----------------------------------")
  10. # loss_func = loss.Loss("nll")
  11. print(callable(tc.nn.NLLLoss))
  12. loss_func = loss.NewLoss(F.nll_loss)
  13. nll_loss = loss.NLLLoss()
  14. #pdb.set_trace()
  15. y = tc.Tensor(
  16. [
  17. [.3,.4,.3],
  18. [.5,.3,.2],
  19. [.3,.6,.1],
  20. ]
  21. )
  22. gy = tc.LongTensor(
  23. [
  24. 0,
  25. 1,
  26. 2,
  27. ]
  28. )
  29. y = tc.log(y)
  30. los = loss_func({'input': y}, {'target': gy})
  31. losses = nll_loss({'input': y}, {'target': gy})
  32. r = -math.log(.3) - math.log(.3) - math.log(.1)
  33. r /= 3
  34. print ("loss = %f" % (los))
  35. print ("r = %f" % (r))
  36. print ("nll_loss = %f" % (losses))
  37. self.assertEqual(int(los * 1000), int(r * 1000))
  38. def _test_case_2(self):
  39. #验证squash()的正确性
  40. print ("----------------------------------")
  41. log = math.log
  42. loss_func = loss.Loss("nll")
  43. #pdb.set_trace()
  44. y = tc.Tensor(
  45. [
  46. [[.3,.4,.3],[.3,.4,.3],],
  47. [[.5,.3,.2],[.1,.2,.7],],
  48. [[.3,.6,.1],[.2,.1,.7],],
  49. ]
  50. )
  51. gy = tc.LongTensor(
  52. [
  53. [0,2],
  54. [1,2],
  55. [2,1],
  56. ]
  57. )
  58. #pdb.set_trace()
  59. y = tc.log(y)
  60. #los = loss_func({'input': y}, {'target': gy})
  61. los = loss_func(y, gy)
  62. print ("loss = %f" % (los))
  63. r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
  64. r /= 6
  65. print ("r = %f" % (r))
  66. self.assertEqual(int(los * 1000), int(r * 1000))
  67. def test_case_3(self):
  68. #验证pack_padded_sequence()的正确性
  69. print ("----------------------------------")
  70. log = math.log
  71. #loss_func = loss.Loss("nll")
  72. loss_func = loss.NLLLoss()
  73. #pdb.set_trace()
  74. y = tc.Tensor(
  75. [
  76. [[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],],
  77. [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],],
  78. [[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],],
  79. ]
  80. )
  81. gy = tc.LongTensor(
  82. [
  83. [0,2,1,],
  84. [1,2,0,],
  85. [2,0,0,],
  86. ]
  87. )
  88. lens = [3,2,1]
  89. #pdb.set_trace()
  90. y = tc.log(y)
  91. yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data
  92. gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data
  93. los = loss_func({'input': yy}, {'target': gyy})
  94. print ("loss = %f" % (los))
  95. r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
  96. r /= 6
  97. print ("r = %f" % (r))
  98. self.assertEqual(int(los * 1000), int(r * 1000))
  99. def test_case_4(self):
  100. #验证unpad()的正确性
  101. print ("----------------------------------")
  102. log = math.log
  103. #pdb.set_trace()
  104. y = tc.Tensor(
  105. [
  106. [[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],],
  107. [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],],
  108. [[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],],
  109. ]
  110. )
  111. gy = tc.LongTensor(
  112. [
  113. [0,2,1,2,],
  114. [1,2,0,0,],
  115. [2,0,0,0,],
  116. ]
  117. )
  118. lens = [4,2,1]
  119. #pdb.set_trace()
  120. y = tc.log(y)
  121. loss_func = loss.Loss("nll" , pre_pro = ["unpad"])
  122. los = loss_func(y , gy , lens = lens)
  123. print ("loss = %f" % (los))
  124. r = -log(.1) -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
  125. r /= 7
  126. print ("r = %f" % (r))
  127. self.assertEqual(int(los * 1000), int(r * 1000))
  128. def test_case_5(self):
  129. #验证mask()和make_mask()的正确性
  130. print ("----------------------------------")
  131. log = math.log
  132. #pdb.set_trace()
  133. y = tc.Tensor(
  134. [
  135. [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],],
  136. [[.5,.4,.1],[.3,.2,.5],[.4,.5,.1,],[.6,.1,.3,],],
  137. [[.3,.6,.1],[.3,.2,.5],[.0,.0,.0,],[.0,.0,.0,],],
  138. ]
  139. )
  140. gy = tc.LongTensor(
  141. [
  142. [1,2,0,0,],
  143. [0,2,1,2,],
  144. [2,1,0,0,],
  145. ]
  146. )
  147. mask = tc.ByteTensor(
  148. [
  149. [1,1,0,0,],
  150. [1,1,1,1,],
  151. [1,1,0,0,],
  152. ]
  153. )
  154. y = tc.log(y)
  155. lens = [2,4,2]
  156. loss_func = loss.Loss("nll" , pre_pro = ["mask"])
  157. los = loss_func(y , gy , mask = mask)
  158. print ("loss = %f" % (los))
  159. los2 = loss_func(y , gy , mask = loss.make_mask(lens,gy.size()[-1]))
  160. print ("loss2 = %f" % (los2))
  161. r = -log(.3) -log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2)
  162. r /= 8
  163. print ("r = %f" % (r))
  164. self.assertEqual(int(los * 1000), int(r * 1000))
  165. self.assertEqual(int(los2 * 1000), int(r * 1000))
  166. def test_case_6(self):
  167. #验证unpad_mask()的正确性
  168. print ("----------------------------------")
  169. log = math.log
  170. #pdb.set_trace()
  171. y = tc.Tensor(
  172. [
  173. [[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],],
  174. [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],],
  175. [[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],],
  176. ]
  177. )
  178. gy = tc.LongTensor(
  179. [
  180. [0,2,1,2,],
  181. [1,2,0,0,],
  182. [2,0,0,0,],
  183. ]
  184. )
  185. lens = [4,2,1]
  186. #pdb.set_trace()
  187. y = tc.log(y)
  188. loss_func = loss.Loss("nll" , pre_pro = ["unpad_mask"])
  189. los = loss_func(y , gy , lens = lens)
  190. print ("loss = %f" % (los))
  191. r = -log(.1) -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
  192. r /= 7
  193. print ("r = %f" % (r))
  194. self.assertEqual(int(los * 1000), int(r * 1000))
  195. def test_case_7(self):
  196. #验证一些其他东西
  197. print ("----------------------------------")
  198. log = math.log
  199. #pdb.set_trace()
  200. y = tc.Tensor(
  201. [
  202. [[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],],
  203. [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],],
  204. [[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],],
  205. ]
  206. )
  207. gy = tc.LongTensor(
  208. [
  209. [0,2,1,2,],
  210. [1,2,0,0,],
  211. [2,0,0,0,],
  212. ]
  213. )
  214. lens = [4,2,1]
  215. #pdb.set_trace()
  216. y = tc.log(y)
  217. loss_func = loss.Loss("nll" , pre_pro = [] , weight = tc.Tensor([1,1,0]))
  218. loss_func.add_pre_pro("unpad_mask")
  219. los = loss_func(y , gy , lens = lens)
  220. print ("loss = %f" % (los))
  221. r = - log(.3) - log(.5) - log(.3)
  222. r /= 3
  223. print ("r = %f" % (r))
  224. self.assertEqual(int(los * 1000), int(r * 1000))
  225. def test_case_8(self):
  226. def func(a, b):
  227. import torch.nn.functional as F
  228. return F.cross_entropy(a, b)
  229. def func2(a, truth):
  230. return func(a, truth)
  231. def func3(predict, truth):
  232. return func(predict, truth)
  233. def func4(a, b, c=2):
  234. return (a + b) * c
  235. def func6(a, b, **kwargs):
  236. c = kwargs['c']
  237. return (a + b) * c
  238. import torch
  239. from fastNLP.core.losses import LossBase, NewLoss
  240. get_loss = NewLoss(func, {'a': 'predict', 'b': 'truth'})
  241. predict = torch.randn(5, 3)
  242. truth = torch.LongTensor([1, 0, 1, 2, 1])
  243. loss1 = get_loss({'predict': predict}, {'truth': truth})
  244. get_loss_2 = NewLoss(func2, {'a': 'predict'})
  245. loss2 = get_loss_2({'predict': predict}, {'truth': truth})
  246. get_loss_3 = NewLoss(func3)
  247. loss3 = get_loss_3({'predict': predict}, {'truth': truth})
  248. print(loss1, loss2, loss3)
  249. assert loss1 == loss2 and loss1 == loss3
  250. get_loss_4 = NewLoss(func4)
  251. loss4 = get_loss_4({'a': 1, 'b': 3}, {})
  252. print(loss4)
  253. assert loss4 == (1 + 3) * 2
  254. get_loss_5 = NewLoss(func4)
  255. loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
  256. print(loss5)
  257. assert loss5 == (1 + 3) * 4
  258. get_loss_6 = NewLoss(func6)
  259. loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
  260. print(loss6)
  261. assert loss6 == (1 + 3) * 4
  262. get_loss_7 = NewLoss(func6, c='cc')
  263. loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
  264. print(loss7)
  265. assert loss7 == (1 + 3) * 4
  266. if __name__ == "__main__":
  267. unittest.main()