You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_distribution.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Test nn.Distribution.
  17. Including Normal Distribution and Bernoulli Distribution.
  18. """
  19. import pytest
  20. import numpy as np
  21. import mindspore.nn as nn
  22. from mindspore import dtype
  23. from mindspore import Tensor
  24. def test_normal_shape_errpr():
  25. """
  26. Invalid shapes.
  27. """
  28. with pytest.raises(ValueError):
  29. nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
  30. def test_no_arguments():
  31. """
  32. No args passed in during initialization.
  33. """
  34. n = nn.Normal()
  35. assert isinstance(n, nn.Distribution)
  36. b = nn.Bernoulli()
  37. assert isinstance(b, nn.Distribution)
  38. def test_with_arguments():
  39. """
  40. Args passed in during initialization.
  41. """
  42. n = nn.Normal([3.0], [4.0], dtype=dtype.float32)
  43. assert isinstance(n, nn.Distribution)
  44. b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32)
  45. assert isinstance(b, nn.Distribution)
  46. class NormalProb(nn.Cell):
  47. """
  48. Normal distribution: initialize with mean/sd.
  49. """
  50. def __init__(self):
  51. super(NormalProb, self).__init__()
  52. self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32)
  53. def construct(self, value):
  54. x = self.normal('prob', value)
  55. y = self.normal('log_prob', value)
  56. return x, y
  57. def test_normal_prob():
  58. """
  59. Test pdf/log_pdf: passing value through construct.
  60. """
  61. net = NormalProb()
  62. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  63. pdf, log_pdf = net(value)
  64. assert isinstance(pdf, Tensor)
  65. assert isinstance(log_pdf, Tensor)
  66. class NormalProb1(nn.Cell):
  67. """
  68. Normal distribution: initialize without mean/sd.
  69. """
  70. def __init__(self):
  71. super(NormalProb1, self).__init__()
  72. self.normal = nn.Normal()
  73. def construct(self, value, mean, sd):
  74. x = self.normal('prob', value, mean, sd)
  75. y = self.normal('log_prob', value, mean, sd)
  76. return x, y
  77. def test_normal_prob1():
  78. """
  79. Test pdf/logpdf: passing mean/sd, value through construct.
  80. """
  81. net = NormalProb1()
  82. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  83. mean = Tensor([0.0], dtype=dtype.float32)
  84. sd = Tensor([1.0], dtype=dtype.float32)
  85. pdf, log_pdf = net(value, mean, sd)
  86. assert isinstance(pdf, Tensor)
  87. assert isinstance(log_pdf, Tensor)
  88. class NormalProb2(nn.Cell):
  89. """
  90. Normal distribution: initialize with mean/sd.
  91. """
  92. def __init__(self):
  93. super(NormalProb2, self).__init__()
  94. self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32)
  95. def construct(self, value, mean, sd):
  96. x = self.normal('prob', value, mean, sd)
  97. y = self.normal('log_prob', value, mean, sd)
  98. return x, y
  99. def test_normal_prob2():
  100. """
  101. Test pdf/log_pdf: passing mean/sd through construct.
  102. Overwrite original mean/sd.
  103. """
  104. net = NormalProb2()
  105. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  106. mean = Tensor([0.0], dtype=dtype.float32)
  107. sd = Tensor([1.0], dtype=dtype.float32)
  108. pdf, log_pdf = net(value, mean, sd)
  109. assert isinstance(pdf, Tensor)
  110. assert isinstance(log_pdf, Tensor)
  111. class BernoulliProb(nn.Cell):
  112. """
  113. Bernoulli distribution: initialize with probs.
  114. """
  115. def __init__(self):
  116. super(BernoulliProb, self).__init__()
  117. self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32)
  118. def construct(self, value):
  119. return self.bernoulli('prob', value)
  120. class BernoulliLogProb(nn.Cell):
  121. """
  122. Bernoulli distribution: initialize with probs.
  123. """
  124. def __init__(self):
  125. super(BernoulliLogProb, self).__init__()
  126. self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32)
  127. def construct(self, value):
  128. return self.bernoulli('log_prob', value)
  129. def test_bernoulli_prob():
  130. """
  131. Test pmf/log_pmf: passing value through construct.
  132. """
  133. net = BernoulliProb()
  134. value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
  135. pmf = net(value)
  136. assert isinstance(pmf, Tensor)
  137. def test_bernoulli_log_prob():
  138. """
  139. Test pmf/log_pmf: passing value through construct.
  140. """
  141. net = BernoulliLogProb()
  142. value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
  143. log_pmf = net(value)
  144. assert isinstance(log_pmf, Tensor)
  145. class BernoulliProb1(nn.Cell):
  146. """
  147. Bernoulli distribution: initialize without probs.
  148. """
  149. def __init__(self):
  150. super(BernoulliProb1, self).__init__()
  151. self.bernoulli = nn.Bernoulli()
  152. def construct(self, value, probs):
  153. return self.bernoulli('prob', value, probs)
  154. class BernoulliLogProb1(nn.Cell):
  155. """
  156. Bernoulli distribution: initialize without probs.
  157. """
  158. def __init__(self):
  159. super(BernoulliLogProb1, self).__init__()
  160. self.bernoulli = nn.Bernoulli()
  161. def construct(self, value, probs):
  162. return self.bernoulli('log_prob', value, probs)
  163. def test_bernoulli_prob1():
  164. """
  165. Test pmf/log_pmf: passing probs through construct.
  166. """
  167. net = BernoulliProb1()
  168. value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
  169. probs = Tensor([0.3], dtype=dtype.float32)
  170. pmf = net(value, probs)
  171. assert isinstance(pmf, Tensor)
  172. def test_bernoulli_log_prob1():
  173. """
  174. Test pmf/log_pmf: passing probs through construct.
  175. """
  176. net = BernoulliLogProb1()
  177. value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
  178. probs = Tensor([0.3], dtype=dtype.float32)
  179. log_pmf = net(value, probs)
  180. assert isinstance(log_pmf, Tensor)
  181. class BernoulliProb2(nn.Cell):
  182. """
  183. Bernoulli distribution: initialize with probs.
  184. """
  185. def __init__(self):
  186. super(BernoulliProb2, self).__init__()
  187. self.bernoulli = nn.Bernoulli(0.5)
  188. def construct(self, value, probs):
  189. return self.bernoulli('prob', value, probs)
  190. class BernoulliLogProb2(nn.Cell):
  191. """
  192. Bernoulli distribution: initialize with probs.
  193. """
  194. def __init__(self):
  195. super(BernoulliLogProb2, self).__init__()
  196. self.bernoulli = nn.Bernoulli(0.5)
  197. def construct(self, value, probs):
  198. return self.bernoulli('log_prob', value, probs)
  199. def test_bernoulli_prob2():
  200. """
  201. Test pmf/log_pmf: passing probs/value through construct.
  202. Overwrite original probs.
  203. """
  204. net = BernoulliProb2()
  205. value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
  206. probs = Tensor([0.3], dtype=dtype.float32)
  207. pmf = net(value, probs)
  208. assert isinstance(pmf, Tensor)
  209. def test_bernoulli_log_prob2():
  210. """
  211. Test pmf/log_pmf: passing probs/value through construct.
  212. Overwrite original probs.
  213. """
  214. net = BernoulliLogProb2()
  215. value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
  216. probs = Tensor([0.3], dtype=dtype.float32)
  217. log_pmf = net(value, probs)
  218. assert isinstance(log_pmf, Tensor)
  219. class NormalKl(nn.Cell):
  220. """
  221. Test class: kl_loss of Normal distribution.
  222. """
  223. def __init__(self):
  224. super(NormalKl, self).__init__()
  225. self.n = nn.Normal(Tensor([3.0]), Tensor([4.0]), dtype=dtype.float32)
  226. def construct(self, x_, y_):
  227. return self.n('kl_loss', 'Normal', x_, y_)
  228. class BernoulliKl(nn.Cell):
  229. """
  230. Test class: kl_loss between Bernoulli distributions.
  231. """
  232. def __init__(self):
  233. super(BernoulliKl, self).__init__()
  234. self.b = nn.Bernoulli(0.7, dtype=dtype.int32)
  235. def construct(self, x_):
  236. return self.b('kl_loss', 'Bernoulli', x_)
  237. def test_kl():
  238. """
  239. Test kl_loss function.
  240. """
  241. nor_net = NormalKl()
  242. mean_b = np.array([1.0]).astype(np.float32)
  243. sd_b = np.array([1.0]).astype(np.float32)
  244. mean = Tensor(mean_b, dtype=dtype.float32)
  245. sd = Tensor(sd_b, dtype=dtype.float32)
  246. loss = nor_net(mean, sd)
  247. assert isinstance(loss, Tensor)
  248. ber_net = BernoulliKl()
  249. probs_b = Tensor([0.3], dtype=dtype.float32)
  250. loss = ber_net(probs_b)
  251. assert isinstance(loss, Tensor)
  252. class NormalKlNoArgs(nn.Cell):
  253. """
  254. Test class: kl_loss of Normal distribution.
  255. No args during initialization.
  256. """
  257. def __init__(self):
  258. super(NormalKlNoArgs, self).__init__()
  259. self.n = nn.Normal(dtype=dtype.float32)
  260. def construct(self, x_, y_, w_, v_):
  261. return self.n('kl_loss', 'Normal', x_, y_, w_, v_)
  262. class BernoulliKlNoArgs(nn.Cell):
  263. """
  264. Test class: kl_loss between Bernoulli distributions.
  265. No args during initialization.
  266. """
  267. def __init__(self):
  268. super(BernoulliKlNoArgs, self).__init__()
  269. self.b = nn.Bernoulli(dtype=dtype.int32)
  270. def construct(self, x_, y_):
  271. return self.b('kl_loss', 'Bernoulli', x_, y_)
  272. def test_kl_no_args():
  273. """
  274. Test kl_loss function.
  275. """
  276. nor_net = NormalKlNoArgs()
  277. mean_b = np.array([1.0]).astype(np.float32)
  278. sd_b = np.array([1.0]).astype(np.float32)
  279. mean_a = np.array([2.0]).astype(np.float32)
  280. sd_a = np.array([3.0]).astype(np.float32)
  281. mean_b = Tensor(mean_b, dtype=dtype.float32)
  282. sd_b = Tensor(sd_b, dtype=dtype.float32)
  283. mean_a = Tensor(mean_a, dtype=dtype.float32)
  284. sd_a = Tensor(sd_a, dtype=dtype.float32)
  285. loss = nor_net(mean_b, sd_b, mean_a, sd_a)
  286. assert isinstance(loss, Tensor)
  287. ber_net = BernoulliKlNoArgs()
  288. probs_b = Tensor([0.3], dtype=dtype.float32)
  289. probs_a = Tensor([0.7], dtype=dtype.float32)
  290. loss = ber_net(probs_b, probs_a)
  291. assert isinstance(loss, Tensor)
  292. class NormalBernoulli(nn.Cell):
  293. """
  294. Test class: basic mean/sd function.
  295. """
  296. def __init__(self):
  297. super(NormalBernoulli, self).__init__()
  298. self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32)
  299. self.b = nn.Bernoulli(0.5, dtype=dtype.int32)
  300. def construct(self):
  301. normal_mean = self.n('mean')
  302. normal_sd = self.n('sd')
  303. bernoulli_mean = self.b('mean')
  304. bernoulli_sd = self.b('sd')
  305. return normal_mean, normal_sd, bernoulli_mean, bernoulli_sd
  306. def test_bascis():
  307. """
  308. Test mean/sd functionality of Normal and Bernoulli.
  309. """
  310. net = NormalBernoulli()
  311. normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net()
  312. assert isinstance(normal_mean, Tensor)
  313. assert isinstance(normal_sd, Tensor)
  314. assert isinstance(bernoulli_mean, Tensor)
  315. assert isinstance(bernoulli_sd, Tensor)