You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_parameter.py 8.9 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore.context as context
  17. import mindspore.ops.composite as C
  18. from mindspore import Tensor, Parameter
  19. from mindspore.nn import Cell
  20. from mindspore.ops import operations as P
  21. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  22. def test_parser_three_default_mixed_args_subnet():
  23. class SubNetDefaultMixedArgs(Cell):
  24. def __init__(self):
  25. super().__init__()
  26. def construct(self, y, x=3, x1=None, x2=(1, 2)):
  27. if x == 3:
  28. if x1 == None:
  29. return y
  30. return -y
  31. class NetOut(Cell):
  32. def __init__(self):
  33. super(NetOut, self).__init__()
  34. self.net_inside = SubNetDefaultMixedArgs()
  35. def construct(self, x, y=3):
  36. z = self.net_inside(x)
  37. return z
  38. tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
  39. tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
  40. net = NetOut()
  41. assert net(tensor1, tensor2) == tensor1
  42. # pylint: disable=keyword-arg-before-vararg
  43. def test_net_vararg_kwonlyarg_kwarg():
  44. class FirstNet(Cell):
  45. def __init__(self):
  46. super(FirstNet, self).__init__()
  47. self.net = SecondNet()
  48. def construct(self, x=1, z=2 + 2 + 4, y=3):
  49. c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
  50. return c
  51. class SecondNet(Cell):
  52. def __init__(self):
  53. super(SecondNet, self).__init__()
  54. def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
  55. a = x - y
  56. b = p * q
  57. c = a / b
  58. d = var[0] * var[1] * var[2] * var[3]
  59. e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
  60. return a + b + c + d + e
  61. net = FirstNet()
  62. net()
  63. # pylint: disable=keyword-arg-before-vararg
  64. def test_net_vararg_normal_input():
  65. class FirstNet(Cell):
  66. def __init__(self):
  67. super(FirstNet, self).__init__()
  68. self.net = SecondNet()
  69. def construct(self, x=1, z=2 + 2 + 4, y=3):
  70. c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
  71. return c
  72. class SecondNet(Cell):
  73. def __init__(self):
  74. super(SecondNet, self).__init__()
  75. def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
  76. a = x - y
  77. b = p * q
  78. c = a / b
  79. d = var[0] * var[1] * var[2] * var[3]
  80. e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
  81. return a + b + c + d + e
  82. x = Tensor(np.ones((2, 3, 4), np.int32))
  83. net = FirstNet()
  84. net(x, x, x)
  85. def test_prim_vararg_kwonlyarg():
  86. class FirstNet(Cell):
  87. def __init__(self):
  88. super(FirstNet, self).__init__()
  89. self.max = P.Maximum()
  90. self.min = P.Minimum()
  91. self.net = SecondNet()
  92. self.x = Tensor(np.ones((2, 3, 4), np.float32))
  93. self.y = Tensor(np.ones((2, 3, 4), np.float32))
  94. def construct(self):
  95. a = self.max(self.x, self.y)
  96. b = self.min(self.x, self.y)
  97. t = {"x": a, "y": b}
  98. c = self.net(t["x"], t["y"], a, b, z=a, r=b)
  99. return c
  100. class SecondNet(Cell):
  101. def __init__(self):
  102. super(SecondNet, self).__init__()
  103. self.addN = P.AddN()
  104. self.max = P.Maximum()
  105. self.add = P.TensorAdd()
  106. def construct(self, x, y, *args, z=0, r=1):
  107. c = self.max(args[0], args[1])
  108. d = self.addN(args)
  109. e = self.max(*args)
  110. ret = x + y + c + d + e + z + r
  111. return ret
  112. net = FirstNet()
  113. net()
  114. def test_no_vararg():
  115. class FirstNet(Cell):
  116. def __init__(self):
  117. super(FirstNet, self).__init__()
  118. self.max = P.Maximum()
  119. self.min = P.Minimum()
  120. self.net = SecondNet()
  121. self.x = Tensor(np.ones((2, 3, 4), np.float32))
  122. self.y = Tensor(np.ones((2, 3, 4), np.float32))
  123. def construct(self):
  124. t = {"x": self.x, "y": self.y}
  125. a = self.max(self.x, self.y)
  126. b = self.min(self.x, self.y)
  127. c = self.net(a, b, z=a, r=b)
  128. return c
  129. class SecondNet(Cell):
  130. def __init__(self):
  131. super(SecondNet, self).__init__()
  132. def construct(self, x, y, *, z=0, r=1):
  133. ret = x + y + z + r
  134. return ret
  135. net = FirstNet()
  136. net()
  137. def test_net_variable_and_weights():
  138. class FirstNet(Cell):
  139. def __init__(self):
  140. super(FirstNet, self).__init__()
  141. self.max = P.Maximum()
  142. self.min = P.Minimum()
  143. self.net = SecondNet()
  144. self.x = Tensor(np.ones((3, 4), np.float32))
  145. self.y = Tensor(np.ones((3, 4), np.float32))
  146. self.weight = Parameter(Tensor(np.ones((2, 3, 4)).astype(np.float32)), "w1", requires_grad=True)
  147. def construct(self, *args):
  148. t = (self.x, self.y)
  149. a = self.max(self.x, self.weight)
  150. b = self.min(self.weight, args[0])
  151. c = self.net(a, b, *t)
  152. return c
  153. class SecondNet(Cell):
  154. def __init__(self):
  155. super(SecondNet, self).__init__()
  156. self.addN = P.AddN()
  157. self.max = P.Maximum()
  158. self.add = P.TensorAdd()
  159. self.weight = Parameter(Tensor(np.ones((2, 3, 4), np.float32)), "w2", requires_grad=True)
  160. def construct(self, a, b, *args):
  161. c = self.max(args[0], a)
  162. d = self.addN(args)
  163. ret = a + b + c + d + self.weight
  164. return ret
  165. net = FirstNet()
  166. x = Tensor(np.ones((4,), np.float32))
  167. y = Tensor(np.ones((4,), np.float32))
  168. z = Tensor(np.ones((4,), np.float32))
  169. net(x, y, z)
  170. def test_net_vargs_expand():
  171. class InputBackward(Cell):
  172. """ InputBackward definition """
  173. def __init__(self, network, c1=None, c2=None):
  174. super(InputBackward, self).__init__()
  175. self.network = network
  176. self.network.set_train()
  177. self.grad = C.grad_all_with_sens
  178. self.c1 = c1
  179. self.c2 = c2
  180. def construct(self, *inputs):
  181. return self.grad(self.network)(*inputs)
  182. class AddNet(Cell):
  183. def __init__(self):
  184. super(AddNet, self).__init__()
  185. def construct(self, x, y):
  186. return x + y
  187. net = InputBackward(AddNet())
  188. x = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  189. y = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  190. sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  191. net.set_train()
  192. net(x, y, sens)
  193. def test_mixed_precision_const_parameter():
  194. class NetLoss(Cell):
  195. def __init__(self):
  196. super(NetLoss, self).__init__()
  197. self.shape = P.Shape()
  198. self.up_sample1 = P.ResizeBilinear((14, 14))
  199. self.up_sample2 = P.ResizeBilinear((28, 28))
  200. self.up_sample3 = P.ResizeBilinear((36, 36))
  201. def construct(self, x, y, z, *args):
  202. ret = 0
  203. if args[0] == self.shape(z)[2]:
  204. if args[0] == 14:
  205. ret = self.up_sample1(y) + x
  206. elif args[0] == 28:
  207. ret = self.up_sample2(y) - x
  208. else:
  209. ret = x / y
  210. else:
  211. ret = x * y
  212. ret = ret * z
  213. return ret
  214. class NetMain(Cell):
  215. def __init__(self, loss_fn):
  216. super(NetMain, self).__init__()
  217. self.loss_fn = loss_fn
  218. self.shape = P.Shape()
  219. def construct(self, x, y, z):
  220. size_x = self.shape(x)[2]
  221. size_y = self.shape(y)[2]
  222. ret = self.loss_fn(x, y, z, size_x, size_y)
  223. return ret
  224. loss_fn = NetLoss()
  225. net = NetMain(loss_fn)
  226. net.add_flags_recursive(fp32=True)
  227. x = Tensor(np.ones((1, 3, 28, 28), np.float32))
  228. y = Tensor(np.ones((1, 3, 14, 14), np.float32))
  229. z = Tensor(np.ones((1, 3, 28, 28), np.float32))
  230. _ = net(x, y, z)