test_framstruct.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_framstruct """
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import WithGradCell, WithLossCell
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.check_gradient import (
    ms_function, check_jacobian, Tensor, NNGradChecker,
    OperationGradChecker, check_gradient, ScalarGradChecker)
from ....mindspore_test_framework.utils.bprop_util import bprop


def setup_module(module):
    context.set_context(mode=context.PYNATIVE_MODE)
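

# The tests below exercise Python control flow (recursion, while/for loops and
# conditionals) as compiled through @ms_function, plus gradient checks via
# C.grad / C.grad_all and the gradient-checker utilities.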
@ms_function
def refactor_fac(n):
    """ grad_refactor_fac """
    if n == 0:
        return 1
    return n * refactor_fac(n - 1)


def test_refactor():
    res = refactor_fac(3)
    assert res == 6


@ms_function
def while_upper_bound(upper):
    rval = 2
    while rval < upper:
        rval = rval * rval
    return rval


def test_while_upper_bound():
    res = while_upper_bound(10)
    assert res == 16


@ms_function
def while_lower_bound(lower):
    """ t_while """
    rval = lower
    while rval < 100:
        rval = rval * rval
    return rval


def test_while_lower_bound():
    res = while_lower_bound(2)
    assert res == 256


@ms_function
def dynamic_make_tuple(x, lower, upper):
    out = ()
    i = lower
    while i < upper:
        out = out + (x,)
        i = i + 1
    return out


def test_dynamic_make_tuple():
    # Dynamically growing a tuple inside a data-dependent loop is invalid in
    # MindSpore, since the compiled graph requires statically typed values.
    with pytest.raises(RuntimeError):
        dynamic_make_tuple(2, 1, 5)


def test_make_tuple():
    # Statically building the tuple (loop bound known at compile time) is valid.
    @ms_function
    def make_tuple(x):
        out = ()
        for i in range(3):
            out = out + (x,)
        return out

    res = make_tuple(5)
    assert res == (5, 5, 5)


@ms_function
def add(x, y):
    """ add """
    return x + y


def mul(x, y):
    """ mul """
    return x * y


def add_mul(x, y):
    """ add_mul """
    return (x + y) * y


def mainf(x, y):
    """ mainf """
    return C.grad_all(mul)(x, y)


def grad_add_mul(x, y):
    """ grad_add_mul """
    return C.grad_all(add_mul)(x, y)


@ms_function
def sub(x, y):
    """ sub """
    return x - y


@ms_function
def if_always_true(x):
    """ if_always_true """
    if True:
        return x
    else:
        return 0


def test_add():
    """ test_add """
    res = add(2.5, 3)
    assert res == 5.5


def test_sub():
    """ test_sub """
    res = sub(3.5, 3)
    assert res == 0.5


@non_graph_engine
def test_if_always_true():
    """ test_if_always_true """
    res = if_always_true(1)
    assert res == 1


@non_graph_engine
def test_f():
    """ test_f """
    res = mainf(3, 2)
    assert res == (2, 3)


@non_graph_engine
def test_grad_add_mul():
    """ test_grad_add_mul """
    res = grad_add_mul(3, 2)
    assert res == (2, 7)


def f(x):
    if x > 0:
        return f(x - 1)
    return x


@ms_function
def list_subscript():
    """ list_subscript """
    x = [1, 2, 3]
    return x[0] * x[1]


def test_list_subscript():
    """ test_list_subscript """
    res = list_subscript()
    assert res == 2


@ms_function
def ms_infer_for(xs, y):
    """ ms_infer_for """
    rval = y
    for x in xs:
        rval = rval + x
    return rval


def test_infer_for():
    """ test_infer_for """
    t = (1, 2, 3)
    y = 4
    res = ms_infer_for(t, y)
    assert res == 10


@ms_function
def if_construct(a, b):
    z = a
    if a > b:
        z = a + b
    else:
        z = a * b
    if z > b:
        return z - a
    else:
        return a - b


def test_if_construct():
    """ test_if_construct """
    res = if_construct(3, 6)
    assert res == 15


@ms_function
def if_scalar(a, b):
    """ if_abstract """
    if a:
        return a
    return b


def test_if_scalar1():
    """ test_if_abstract """
    res = if_scalar(3, 6)
    assert res == 3


def test_if_scalar2():
    """ test_if_abstract """
    res = if_scalar(0, 6)
    assert res == 6


@ms_function
def if_tensor(a, b):
    c = a
    if a < b:
        c = a + a
        if c < b:
            c = a + c
        else:
            c = a + b
    else:
        c = b + b
    out = c + c
    return out


def test_if_tensor():
    res = if_tensor(Tensor(np.ones([64, 10]).astype(np.int32)), Tensor(np.ones([64, 10]).astype(np.int32)))
    assert res == Tensor(np.ones([64, 10]).astype(np.int32) * 4)


@ms_function
def rec(x):
    """ rec """
    if x > 0:
        return rec(x - 1)
    return x
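

# rec(10) peels the input down to rec(0) and returns 0; the output reaches the
# input through ten (x - 1) steps, so d(out)/dx == 1.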
def test_grad_rec():
    """ test_grad_rec """
    res = C.grad(rec)(10)
    assert res == 1


def test_me_rec():
    """ test_me_rec """
    res = rec(10)
    assert res == 0


@ms_function
def t2_while(x, y):
    out = y - x
    i = 0
    while i < 10:
        out = mul(x, y)
        i = i + 1
    return out
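

# The loop body overwrites out on every pass, so t2_while reduces to x * y;
# the forward result is 6 and d(x * y)/dx = y = 3.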
def test_while2():
    res = t2_while(2, 3)
    assert res == 6


def test_grad_while2():
    res = C.grad(t2_while)(2, 3)
    assert res == 3


def if_test(a, b):
    """ if_test """
    if a > b:
        return 3 * a
    return 2 * b


def grad_if(x, y):
    """ grad_if """
    return C.grad_all(if_test)(x, y)
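

# For a=5 > b=4 the 3 * a branch is taken, so the gradients w.r.t. (a, b) are (3, 0).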
def test_grad_if():
    """ test_grad_if """
    assert grad_if(5, 4) == (3, 0)


# While loop is not unrolled in forward and backward graphs.
def test_dont_unroll_while():
    def dont_unroll_while(x, y):
        i = 2
        out = y - x
        while i < 10:
            out = mul(x, y)
            i = i + 1
        return out

    @ms_function()
    def invoke_while(x, y):
        return C.grad(dont_unroll_while)(x, y)

    res = invoke_while(2, 3)
    assert res == 3


class ConvNet(nn.Cell):
    def __init__(self):
        super(ConvNet, self).__init__()
        out_channel = 16
        kernel_size = 3
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=0,
                             stride=1,
                             dilation=2,
                             group=1)
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        return self.conv(x, self.w)


conv = ConvNet()
c1 = Tensor([2], mstype.float32)
c2 = Tensor([10], mstype.float32)
c3 = Tensor([1], mstype.float32)


@ms_function
def t1_while(x, y, z):
    out = x
    i = c1
    while i < c2:
        out = out + conv(z)
        i = i + c3
    out = out + out
    return out
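

# With all-ones weights and input, each conv(z) element sums 16 channels * 3 * 3
# kernel taps = 144; the loop runs 8 times (i = 2..9), so out = 1 + 8 * 144 = 1153,
# doubled to 2306.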
def test_while_net():
    y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
    x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
    z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    res = t1_while(x, y, z)
    assert res == Tensor(np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)


@ms_function
def if_while(a, b, x, z):
    c = a
    i = c1
    out = x
    if a < b:
        c = a + a
        while i < c2:
            out = out + conv(z)
            i = i + c3
    else:
        c = b + b
    out = c + c
    return out


def test_if_while():
    x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
    z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
    res = if_while(Tensor(np.ones([64, 10]).astype(np.float32)), Tensor(np.ones([64, 10]).astype(np.float32)), x, z)
    assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)


def _while(x):
    """ _while """
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def grad_while(x):
    """ grad_while """
    return C.grad_all(_while)(x)
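

# _while computes x * x * 2 * 3 = 6 * x**2, so the gradient is 12 * x = 60 at x = 5.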
def test_grad_while():
    """ test_grad_while """
    assert grad_while(5) == (60,)


@ms_function
def fac(n):
    """ fac """
    if n == 0:
        return 1
    return n * fac(n - 1)


def test_fac():
    """ test_fac """
    res = fac(4)
    assert res == 24


def _for(x):
    """ _for """
    ret = x * x
    for i in (2, 3):
        ret = ret * i
    return ret


def grad_for(x):
    """ grad_for """
    return C.grad_all(_for)(x)
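

# Same computation as _while, but via an unrolled for loop: 6 * x**2, gradient 12 * x.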
def test_grad_for():
    """ test_grad_for """
    assert grad_for(5) == (60,)


@ms_function
def try_tail(x):
    """ try_tail """
    return C.tail(x)


@non_graph_engine
def test_tail():
    """ test_tail """
    try_tail((0, 1, 2, 3))


@ms_function
def zero_like_tensor(x):
    """ zero_like_tensor """
    return C.zeros_like(x)


def test_zeros():
    """ test_zeros """
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    res = zero_like_tensor(x)
    assert res == Tensor(np.zeros([2, 3]).astype(np.int32))


def test_ScalarGradChecker():
    """ test_ScalarGradChecker """
    def scalar_f(x, y):
        return x * y

    check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1)


def test_GradCheckerPrimitive():
    """ test_GradCheckerPrimitive """
    matmul = P.MatMul()

    def prim_f(x, y):
        return matmul(x, y)

    check_gradient(prim_f, Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker, sampling_times=2)


def test_NNGradChecker():
    """ test_NNGradChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out

    check_gradient(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-3,
                   grad_checker_class=NNGradChecker, sampling_times=3)


def test_OperationGradChecker():
    """ test_OperationGradChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

    check_gradient(Net(), Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)), grad_checker_class=OperationGradChecker,
                   input_selector=[1], sampling_times=2)


def test_ScalarJacobianChecker():
    """ test_ScalarJacobianChecker """
    def scalar_f(x, y):
        return x * y

    check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0])


def test_OperationJacobianChecker():
    """ test_OperationJacobianChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return x, out

    check_jacobian(Net(), Tensor(np.array([[0.65, 0.8, 0.8], [0.1, 0.2, 0.3]], np.float32)),
                   Tensor(np.array([[0.1, 0.3], [0.2, 0.2], [-.1, 0.4]], np.float32)),
                   grad_checker_class=OperationGradChecker, input_selector=[0],
                   output_selector=[0])


def test_NNJacobianChecker():
    """ test_NNJacobianChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out, x

    check_jacobian(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-7,
                   grad_checker_class=NNGradChecker,
                   input_selector=[1],
                   output_selector=[0])


def multi_outputs(x, y):
    z = x + y
    return 2 * z, 2 * z
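

# With sens (1, 1), the gradients from the two outputs sum: d/dx = 2 + 2 = 4,
# and likewise for y.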
def test_grad_multi_outputs():
    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)


@ms_function
def while_sp(x, y, z):
    out = x
    i = c3
    while i < c2:
        out = mul(x, out)
        i = i + c3
    return out
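

# out starts at x and is multiplied by x once per iteration for i = 1..9,
# giving 2.0 ** 10 = 1024.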
def test_while_sp():
    y = Tensor(np.ones([1, 3]).astype(np.float32))
    z = Tensor(np.ones([1, 3]).astype(np.float32))
    x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
    res = while_sp(x, y, z)
    assert res == Tensor(np.ones([1, 3]).astype(np.float32) * 1024.0)


def grad_refactor_simple_1(x, y):
    """ grad_refactor_simple_1 """
    return x * x + 2 * y
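

# d(x*x + 2*y)/dx = 2*x = 4 at x = 2, and d/dy = 2.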
def test_grad_refactor_simple_1():
    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)


def grad_refactor_simple_2(x, y, z):
    """ grad_refactor_simple_2 """
    return x * y + z + x * y * z + x + x * y
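

# At (2, 3, 0): d/dx = y + y*z + 1 + y = 7, d/dy = x + x*z + x = 4, d/dz = 1 + x*y = 7.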
def test_grad_refactor_simple_2():
    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)


def grad_refactor_1(a, b):
    """ grad_refactor_1 """
    def inner(x, y):
        return x * y

    return inner(a, b)


def test_grad_refactor_1():
    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)


def grad_refactor_2(a, b):
    """ grad_refactor_2 """
    def inner(x):
        return x * b

    return inner(b) * inner(a)
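

# inner(b) * inner(a) = (b*b) * (a*b) = a * b**3; at (2, 3): d/da = b**3 = 27
# and d/db = 3 * a * b**2 = 54.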
def test_grad_refactor_2():
    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)


def grad_refactor_3(a):
    """ grad_refactor_3 """
    if a > 3:
        return 0
    return 3 * a


def test_grad_refactor_3():
    assert C.grad_all(grad_refactor_3)(3) == (3,)


def grad_refactor_4(a):
    """ grad_refactor_4 """
    if a > 3:
        return 3 * a
    return 0


def test_grad_refactor_4():
    assert C.grad_all(grad_refactor_4)(4) == (3,)


def grad_refactor_5(a):
    """ grad_refactor_5 """
    if a > 3:
        return 1
    return a


def test_grad_refactor_5():
    assert C.grad_all(grad_refactor_5)(1) == (1,)


def grad_refactor_6(a, b):
    """ grad_refactor_6 """
    if a > b:
        return 3 * a + b
    return 2 * b * a


def test_grad_refactor_6():
    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)


def grad_refactor_while(x):
    """ grad_refactor_while """
    rval = x
    while rval < 4:
        rval = rval * rval
    return rval
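

# For x = 3 the loop squares once: rval = x * x, so the gradient is 2 * x = 6.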
def test_grad_refactor_9():
    assert C.grad_all(grad_refactor_while)(3) == (6,)


def grad_refactor__while_1(x):
    """ _while """
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def test_grad_refactor_10():
    """ test_grad_while """
    assert C.grad_all(grad_refactor__while_1)(5) == (60,)


def test_grad_refactor_11():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            return x * y * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32)))


def test_grad_refactor_12():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def test_grad_refactor_13():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.ones([2]).astype(np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    weights = ParameterTuple(net.trainable_params())
    C.grad_by_list(net, weights)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def grad_refactor_14(a, b):
    """ grad_refactor_14 """
    def inner1(x):
        return x * b

    def inner2(x):
        return a * b

    def inner3(x):
        if x > 2:
            return a
        return b

    return inner1(b) + inner2(a) + inner3(a)
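

# At (2, 3): inner1(b) = b*b, inner2(a) = a*b, and inner3(a) = b since a = 2 is not > 2;
# f = b*b + a*b + b, so d/da = b = 3 and d/db = 2*b + a + 1 = 9.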
def test_grad_refactor_14():
    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)