You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_math_ops.py 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test math ops """
  16. import functools
  17. import numpy as np
  18. import pytest
  19. import mindspore as ms
  20. import mindspore.context as context
  21. import mindspore.nn as nn
  22. from mindspore import Tensor
  23. from mindspore.common import dtype as mstype
  24. from mindspore.ops import composite as C
  25. from mindspore.ops import operations as P
  26. from mindspore.ops import prim_attr_register, PrimitiveWithInfer
  27. from ..ut_filter import non_graph_engine
  28. from ....mindspore_test_framework.mindspore_test import mindspore_test
  29. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  30. import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
  31. from ....mindspore_test_framework.pipeline.forward.verify_exception \
  32. import pipeline_for_verify_exception_for_case_by_case_config
  # Compile the networks in this module in MindSpore graph mode by default
  # (test_exec also sets this explicitly before returning its cases).
  context.set_context(mode=context.GRAPH_MODE)
  # Module-wide pylint suppressions used by the helper cells/primitives below:
  # pylint: disable=W0613
  # pylint: disable=W0231
  # W0613: unused-argument
  # W0231: super-init-not-called
  def test_multiply():
      """Element-wise Mul of two float tensors matches the NumPy product within 1e-6."""
      lhs = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]))
      rhs = Tensor(np.array([[0.1, 0.3, -3.6], [0.4, 0.5, -3.2]]))
      product = P.Mul()(lhs, rhs)
      wanted = np.array([[-0.01, 0.09, -12.96], [0.16, 0.25, 10.24]])
      tolerance = np.full((2, 3), 1.0e-6)
      delta = product.asnumpy() - wanted
      # Two one-sided checks together bound |delta| by the tolerance.
      assert np.all(delta < tolerance)
      assert np.all(-delta < tolerance)
  def test_sub():
      """Subtracting a zero tensor from a ones tensor leaves the ones unchanged."""
      minuend = Tensor(np.ones(shape=[3]))
      subtrahend = Tensor(np.zeros(shape=[3]))
      difference = P.Sub()(minuend, subtrahend)
      assert (difference.asnumpy() == np.ones(shape=[3])).all()
  def test_square():
      """Square squares every element of a 2x3 integer tensor."""
      source = Tensor(np.array([[1, 2, 3], [4, 5, 6]]))
      squared = P.Square()(source)
      assert (squared.asnumpy() == np.array([[1, 4, 9], [16, 25, 36]])).all()
  def test_sqrt():
      """Sqrt of a tensor of perfect squares yields the integer roots."""
      op = P.Sqrt()
      roots = op(Tensor(np.array([[4, 4], [9, 9]])))
      assert (roots.asnumpy() == np.array([[2, 2], [3, 3]])).all()
  class PowNet(nn.Cell):
      """Cell wrapping the Pow primitive: ``construct(x, y)`` returns ``Pow(x, y)``."""

      def __init__(self):
          super(PowNet, self).__init__()
          self.pow = P.Pow()

      def construct(self, x, y):
          powered = self.pow(x, y)
          return powered
  def test_pow():
      """Check Pow with a tensor exponent, then smoke-test bool exponents.

      The primitive call is compared against ``x ** 3``. The two calls through
      ``PowNet`` (a Python ``True`` and a bool tensor exponent) only check that
      compilation and execution succeed; their results are not inspected.
      """
      input_tensor = Tensor(np.array([[2, 2], [3, 3]]))
      power = Tensor(np.array(3, np.int64))
      # ``np.bool`` was a deprecated alias of the builtin ``bool``. NumPy 1.24
      # removed it (AttributeError); ``np.bool_`` works in every NumPy version.
      power2 = Tensor(np.array(True, np.bool_))
      testpow = P.Pow()
      expect = np.array([[8, 8], [27, 27]])
      result = testpow(input_tensor, power)
      assert np.all(result.asnumpy() == expect)
      net = PowNet()
      net(input_tensor, True)
      net(input_tensor, power2)
  def test_exp():
      """Exp on a small integer tensor equals numpy.exp on the same values."""
      exponent = Tensor(np.array([[2, 2], [3, 3]]))
      exponentiated = P.Exp()(exponent)
      reference = np.exp(np.array([[2, 2], [3, 3]]))
      assert (exponentiated.asnumpy() == reference).all()
  def test_realdiv():
      """RealDiv of two scalar tensors equals NumPy true division of their values."""
      numerator = Tensor(2048.0)
      denominator = Tensor(128.0)
      quotient = P.RealDiv()(numerator, denominator)
      reference = numerator.asnumpy() / denominator.asnumpy()
      assert np.all(quotient.asnumpy() == reference)
  def test_eye():
      """Eye(3, 3, float32) produces the 3x3 identity matrix."""
      identity = P.Eye()(3, 3, ms.float32)
      assert np.all(identity.asnumpy() == np.eye(3))
  class VirtualLossGrad(PrimitiveWithInfer):
      """Placeholder gradient primitive returned by VirtualLoss's bprop.

      It has no Python implementation (calling it raises NotImplementedError).
      Shape and dtype inference pass through those of the first input ``x``.
      """

      @prim_attr_register
      def __init__(self):
          """Register the primitive; it takes no attributes."""

      def __call__(self, x, out, dout):
          raise NotImplementedError

      def infer_shape(self, x_shape, out_shape, dout_shape):
          inferred = x_shape
          return inferred

      def infer_dtype(self, x_dtype, out_dtype, dout_dtype):
          inferred = x_dtype
          return inferred
  class VirtualLoss(PrimitiveWithInfer):
      """Placeholder loss primitive placed on top of the networks under test.

      Calling it directly raises NotImplementedError. Inference reports shape
      ``[1]`` and the input dtype. Its bprop delegates to VirtualLossGrad.
      """

      @prim_attr_register
      def __init__(self):
          """Register the primitive; it takes no attributes."""

      def __call__(self, x):
          raise NotImplementedError

      def get_bprop(self):
          grad_op = VirtualLossGrad()

          def bprop(x, out, dout):
              return (grad_op(x, out, dout),)

          return bprop

      def infer_shape(self, x_shape):
          return [1]

      def infer_dtype(self, x_dtype):
          return x_dtype
  class NetWithLoss(nn.Cell):
      """Feed the output of a three-input ``network(x, y, b)`` into VirtualLoss."""

      def __init__(self, network):
          super(NetWithLoss, self).__init__()
          self.loss = VirtualLoss()
          self.network = network

      def construct(self, x, y, b):
          return self.loss(self.network(x, y, b))
  class GradWrap(nn.Cell):
      """Return the gradient of a three-input network using ``C.grad``."""

      def __init__(self, network):
          super(GradWrap, self).__init__()
          self.network = network

      def construct(self, x, y, b):
          grad_fn = C.grad(self.network)
          return grad_fn(x, y, b)
  class MatMulNet(nn.Cell):
      """Compute ``BiasAdd(MatMul(x, y), b)``."""

      def __init__(self):
          super(MatMulNet, self).__init__()
          self.matmul = P.MatMul()
          self.biasAdd = P.BiasAdd()

      def construct(self, x, y, b):
          product = self.matmul(x, y)
          return self.biasAdd(product, b)
  class NetWithLossSub(nn.Cell):
      """Feed the output of a two-input ``network(x, y)`` into VirtualLoss."""

      def __init__(self, network):
          super(NetWithLossSub, self).__init__()
          self.loss = VirtualLoss()
          self.network = network

      def construct(self, x, y):
          return self.loss(self.network(x, y))
  class GradWrapSub(nn.Cell):
      """Return the gradient of a two-input network using ``C.grad``."""

      def __init__(self, network):
          super(GradWrapSub, self).__init__()
          self.network = network

      def construct(self, x, y):
          grad_fn = C.grad(self.network)
          return grad_fn(x, y)
  class SubNet(nn.Cell):
      """Compute ``Sub(x, y)``."""

      def __init__(self):
          super(SubNet, self).__init__()
          self.sub = P.Sub()

      def construct(self, x, y):
          difference = self.sub(x, y)
          return difference
  class NpuFloatNet(nn.Cell):
      """ NpuFloat definition

      Computes ``x - (-x)`` (i.e. ``2 * x``) between NPU float-status clear and
      get calls. If the summed status flags are positive, it returns zeros with
      the result's shape and dtype; otherwise it returns the computed result.
      """
      def __init__(self):
          super(NpuFloatNet, self).__init__()
          self.mul = P.Mul()  # NOTE(review): not used by construct
          self.alloc_status = P.NPUAllocFloatStatus()
          self.get_status = P.NPUGetFloatStatus()
          self.clear_status = P.NPUClearFloatStatus()
          self.fill = P.Fill()
          self.shape_op = P.Shape()
          self.select = P.Select()
          self.less = P.Less()
          self.cast = P.Cast()
          self.dtype = P.DType()
          self.reduce_sum = P.ReduceSum(keep_dims=True)
          self.sub = P.Sub()
          self.neg = P.Neg()
      # NOTE(review): has_effect presumably marks construct as side-effecting so
      # the status ops, whose outputs are not all consumed, are kept — confirm.
      @C.add_flags(has_effect=True)
      def construct(self, x):
          # Allocate the float-status buffer, then reset it before the monitored
          # computation (buffer layout is assumed hardware-defined — TODO confirm).
          init = self.alloc_status()
          self.clear_status(init)
          # x - (-x) == 2 * x. With float16 input at its max (65504) this overflows.
          res = self.sub(x, self.neg(x))
          # Presumably writes the flags raised since the clear back into ``init``.
          self.get_status(init)
          # Sum the flags over axis 0; keep_dims=True keeps that axis with size 1.
          flag_sum = self.reduce_sum(init, (0,))
          # Zeros shaped like ``res`` (created in res's dtype), cast to flag_sum's dtype.
          base = self.cast(self.fill(self.dtype(res), self.shape_op(res), 0.0), self.dtype(flag_sum))
          # True where 0 < flag_sum, i.e. some status flag was raised.
          cond = self.less(base, flag_sum)
          # Zeros (cast back to res's dtype) where flagged, otherwise ``res``.
          out = self.select(cond, self.cast(base, self.dtype(res)), res)
          return out
  class DiagNet(nn.Cell):
      """Subtract the 3x3 identity matrix, built from Fill + Diag, from ``x``."""

      def __init__(self):
          super(DiagNet, self).__init__()
          self.fill = P.Fill()
          self.diag = P.Diag()

      def construct(self, x):
          ones = self.fill(mstype.float32, (3,), 1.0)
          identity = self.diag(ones)
          return x - identity
  class NetWithLossCumSum(nn.Cell):
      """Feed the output of a single-input ``network(input_)`` into VirtualLoss."""

      def __init__(self, network):
          super(NetWithLossCumSum, self).__init__()
          self.loss = VirtualLoss()
          self.network = network

      def construct(self, input_):
          return self.loss(self.network(input_))
  class GradWrapCumSum(nn.Cell):
      """Return the gradient of a single-input network using ``C.grad``."""

      def __init__(self, network):
          super(GradWrapCumSum, self).__init__()
          self.network = network

      def construct(self, input_):
          grad_fn = C.grad(self.network)
          return grad_fn(input_)
  class NetCumSum(nn.Cell):
      """Cumulative sum of the input along axis 1."""

      def __init__(self):
          super(NetCumSum, self).__init__()
          self.cumsum = P.CumSum()
          self.axis = 1

      def construct(self, input_):
          axis = self.axis
          return self.cumsum(input_, axis)
  class SignNet(nn.Cell):
      """Cell wrapping the Sign primitive."""

      def __init__(self):
          super(SignNet, self).__init__()
          self.sign = P.Sign()

      def construct(self, x):
          signs = self.sign(x)
          return signs
  class AssignAdd(nn.Cell):
      """Cell that adds its input into a one-element float32 ``global_step`` parameter.

      NOTE(review): no test case visible in this file instantiates this cell;
      the ``AssignAdd_Error`` case uses the ``P.AssignAdd`` primitive directly.
      """

      def __init__(self):
          # ``Parameter`` and ``initializer`` are not among this module's visible
          # imports, so instantiating the cell previously raised NameError.
          from mindspore import Parameter
          from mindspore.common.initializer import initializer
          super().__init__()
          self.op = P.AssignAdd()
          self.inputdata = Parameter(initializer(1, [1], ms.float32), name="global_step")

      def construct(self, input_):
          # AssignAdd updates the Parameter given as its first operand. The old
          # code first rebound ``self.inputdata`` to ``input_``, which discarded
          # the parameter before the update.
          return self.op(self.inputdata, input_)
  class FloorNet(nn.Cell):
      """Cell wrapping the Floor primitive."""

      def __init__(self):
          super(FloorNet, self).__init__()
          self.floor = P.Floor()

      def construct(self, x):
          floored = self.floor(x)
          return floored
  class Log1pNet(nn.Cell):
      """Cell wrapping the Log1p primitive."""

      def __init__(self):
          super(Log1pNet, self).__init__()
          self.log1p = P.Log1p()

      def construct(self, x):
          logged = self.log1p(x)
          return logged
  class ErfcNet(nn.Cell):
      """Cell wrapping the Erfc primitive."""

      def __init__(self):
          super(ErfcNet, self).__init__()
          self.erfc = P.Erfc()

      def construct(self, x):
          complement = self.erfc(x)
          return complement
  # Cases handed to test_exec's forward-compile pipeline (via test_exec_case).
  # Each entry is (case_name, config):
  # - 'block' is the Cell under test and 'desc_inputs' are its forward inputs.
  # - 'desc_bprop' presumably supplies sensitivities for a backward pass
  #   (confirm against mindspore_test_framework).
  # - Every entry except 'Erfc' sets 'skip': ['backward'].
  test_case_math_ops = [
      # Gradient wrapper around MatMul + BiasAdd under VirtualLoss, int32 operands.
      ('MatMulGrad', {
          'block': GradWrap(NetWithLoss(MatMulNet())),
          'desc_inputs': [Tensor(np.ones([3, 3]).astype(np.int32)),
                          Tensor(np.ones([3, 3]).astype(np.int32)),
                          Tensor(np.ones([3]).astype(np.int32))],
          'desc_bprop': [Tensor(np.ones([3, 3]).astype(np.int32)),
                         Tensor(np.ones([3, 3]).astype(np.int32)),
                         Tensor(np.ones([3]).astype(np.int32))],
          'skip': ['backward']}),
      # Gradient wrapper around CumSum (axis 1) on a 4x4 float16 input.
      ('CumSumGrad', {
          'block': GradWrapCumSum(NetWithLossCumSum(NetCumSum())),
          'desc_inputs': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))],
          'desc_bprop': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))],
          'skip': ['backward']}),
      # x minus the 3x3 identity built inside DiagNet.
      ('Diag', {
          'block': DiagNet(),
          'desc_inputs': [Tensor(np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]], np.float32))],
          'desc_bprop': [Tensor(np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]], np.float32))],
          'skip': ['backward']}),
      # Sub with broadcasting: (5, 3) - (8, 5, 3).
      # NOTE(review): desc_bprop is 3x3, which does not match the broadcast output
      # shape; presumably harmless only because backward is skipped — confirm.
      ('SubBroadcast', {
          'block': GradWrapSub(NetWithLossSub(SubNet())),
          'desc_inputs': [Tensor(np.ones([5, 3])), Tensor(np.ones([8, 5, 3]))],
          'desc_bprop': [Tensor(np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]], np.float32))],
          'skip': ['backward']}),
      # 2 * 655 = 1310 fits in float16, so no overflow is expected.
      ('NpuFloat_NotOverflow', {
          'block': NpuFloatNet(),
          'desc_inputs': [Tensor(np.full((8, 5, 3, 1), 655, dtype=np.float16), dtype=ms.float16)],
          'desc_bprop': [Tensor(np.full((8, 5, 3, 1), 655, dtype=np.float16), dtype=ms.float16)],
          'skip': ['backward']}),
      # 65504 is the float16 maximum, so 2 * 65504 overflows.
      ('NpuFloat_Overflow', {
          'block': NpuFloatNet(),
          'desc_inputs': [Tensor(np.full((8, 5, 3, 1), 65504, dtype=np.float16), dtype=ms.float16)],
          'desc_bprop': [Tensor(np.full((8, 5, 3, 1), 65504, dtype=np.float16), dtype=ms.float16)],
          'skip': ['backward']}),
      ('Sign', {
          'block': SignNet(),
          'desc_inputs': [Tensor(np.array([[1., 0., -2.]], np.float32))],
          'desc_bprop': [Tensor(np.array([[1., 0., -2.]], np.float32))],
          'skip': ['backward']}),
      ('Floor', {
          'block': FloorNet(),
          'desc_inputs': [Tensor(np.array([[1., 0., -2.]], np.float32))],
          'desc_bprop': [Tensor(np.array([[1., 0., -2.]], np.float32))],
          'skip': ['backward']}),
      ('Log1p', {
          'block': Log1pNet(),
          'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
          'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
          'skip': ['backward']}),
      # NOTE(review): unlike the other entries, this one has no 'skip' key.
      ('Erfc', {
          'block': ErfcNet(),
          'desc_inputs': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
          'desc_bprop': [Tensor(np.array([[1.0, 2.0, 4.0]], np.float32))],
      }),
  ]
  test_case_lists = [test_case_math_ops]
  # Flatten all case lists into the single sequence returned by test_exec.
  test_exec_case = [case for case_list in test_case_lists for case in case_list]
  # Use -k to select particular test cases, e.g.:
  # pytest <path-to>/test_math_ops.py::test_exec -k Erfc
  @non_graph_engine
  @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
  def test_exec():
      """Provide test_exec_case to the forward-compile pipeline, in graph mode."""
      context.set_context(mode=context.GRAPH_MODE)
      cases = test_exec_case
      return cases
  # Cases for test_check_exception. Each 'block' is a (callable, expectation)
  # pair naming the exception expected when the callable is presumably invoked
  # on 'desc_inputs' by the verify-exception pipeline.
  raise_set = [
      # The lambdas ignore their argument, so the TypeError must come from
      # constructing StridedSlice with a non-int mask ("1", 1.1, "1.1").
      ('StridedSlice_1_Error', {
          'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': TypeError}),
          'desc_inputs': [0]}),
      ('StridedSlice_2_Error', {
          'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': TypeError}),
          'desc_inputs': [0]}),
      ('StridedSlice_3_Error', {
          'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': TypeError}),
          'desc_inputs': [0]}),
      ('StridedSlice_4_Error', {
          'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
          'desc_inputs': [0]}),
      # Calling AssignAdd on the input [[1]] is expected to raise IndexError.
      ('AssignAdd_Error', {
          'block': (P.AssignAdd(), {'exception': IndexError}),
          'desc_inputs': [[1]]}),
  ]
  @mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
  def test_check_exception():
      """Provide the raise_set cases to the exception-verification pipeline."""
      cases = raise_set
      return cases