
ops_common.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import numpy as np

import mindspore.nn as nn
import mindspore.ops.composite as C
import mindspore.ops.functional as F
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.common.api import _executor

grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)

class InputBackward(nn.Cell):
    """ InputBackward definition """

    def __init__(self, network, c1=None, c2=None):
        super(InputBackward, self).__init__()
        self.network = network
        self.network.set_train()
        self.grad = grad_all_with_sens
        self.c1 = c1
        self.c2 = c2

    def construct(self, *inputs):
        pass

    # constructN takes N network inputs plus a sens (gradient seed) argument;
    # gen_backward_net() below binds the matching variant to `construct`.
    def construct1(self, x1, sens):
        return self.grad(self.network)(x1, sens)

    def construct2(self, x1, x2, sens):
        return self.grad(self.network)(x1, x2, sens)

    def construct3(self, x1, x2, x3, sens):
        return self.grad(self.network)(x1, x2, x3, sens)

    def construct4(self, x1, x2, x3, x4, sens):
        return self.grad(self.network)(x1, x2, x3, x4, sens)

    def construct5(self, x1, x2, x3, x4, x5, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, sens)

    def construct6(self, x1, x2, x3, x4, x5, x6, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, x6, sens)

    def construct7(self, x1, x2, x3, x4, x5, x6, x7, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, x6, x7, sens)

class InputOpNet(nn.Cell):
    """ InputOpNet definition """

    def __init__(self, op, get_first=False,
                 c1=None, c2=None, c3=None, c4=None):
        super(InputOpNet, self).__init__()
        self.op = op
        self.get_first = get_first
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.c4 = c4

    def construct(self, *inputs):
        pass

    # constructN_cM calls the wrapped op with N tensor inputs and M constant
    # arguments; gen_net() below binds the matching variant to `construct`.
    def construct0_c0_fack(self, data):
        x = self.op() + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c1_fack(self, data):
        x = self.op(self.c1) + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c2_fack(self, data):
        x = self.op(self.c1, self.c2) + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c0(self):
        x = self.op()
        if self.get_first:
            x = x[0]
        return x

    def construct0_c1(self):
        x = self.op(self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct0_c2(self):
        x = self.op(self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c0(self, x1):
        x = self.op(x1)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c1(self, x1):
        x = self.op(x1, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c2(self, x1):
        x = self.op(x1, self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c3(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c4(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3, self.c4)
        if self.get_first:
            x = x[0]
        return x

    def constructc1_1(self, x1):
        x = self.op(self.c1, x1)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c0(self, x1, x2):
        x = self.op(x1, x2)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c1(self, x1, x2):
        x = self.op(x1, x2, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c3(self, x1, x2):
        x = self.op(x1, x2, self.c1, self.c2, self.c3)
        if self.get_first:
            x = x[0]
        return x

    def construct3_c0(self, x1, x2, x3):
        x = self.op(x1, x2, x3)
        if self.get_first:
            x = x[0]
        return x

    def construct3_c1(self, x1, x2, x3):
        x = self.op(x1, x2, x3, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct4_c0(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4)
        if self.get_first:
            x = x[0]
        return x

    def construct4_c1(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct5_c0(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5)
        if self.get_first:
            x = x[0]
        return x

    def construct6_c0(self, x1, x2, x3, x4, x5, x6):
        x = self.op(x1, x2, x3, x4, x5, x6)
        if self.get_first:
            x = x[0]
        return x

    def construct5_c1(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5, self.c1)
        if self.get_first:
            x = x[0]
        return x

class NetOutputAsLoss(nn.Cell):
    """ NetOutputAsLoss definition """

    def __init__(self, network, output_index):
        super(NetOutputAsLoss, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, *inputs):
        pass

    def construct1(self, x1):
        predict = self.network(x1)[self.output_index]
        return predict

    def construct2(self, x1, x2):
        predict = self.network(x1, x2)[self.output_index]
        return predict

    def construct3(self, x1, x2, x3):
        predict = self.network(x1, x2, x3)[self.output_index]
        return predict

    def construct4(self, x1, x2, x3, x4):
        predict = self.network(x1, x2, x3, x4)[self.output_index]
        return predict

    def construct5(self, x1, x2, x3, x4, x5):
        predict = self.network(x1, x2, x3, x4, x5)[self.output_index]
        return predict


def get_loss_fun(construct_net, num_input, output_index):
    net = NetOutputAsLoss(construct_net, output_index)
    f = getattr(net, 'construct%d' % num_input)
    setattr(net, "construct", f)
    return net

def build_construct_graph(net, *inputs, execute=True):
    net.set_train()
    _executor.compile(net, *inputs)
    if execute:
        _executor(net, inputs)


def build_backward_graph(net, output_shapes, inputs, execute=True):
    inputs = append_sens_to_inputs(output_shapes, inputs)
    net = gen_backward_net(net, len(inputs) - 1)
    net.set_train()
    _executor.compile(net, inputs)
    if execute:
        _executor(net, inputs)


def convert(shp, dtype=np.float32, scale=6):
    # A list is treated as a shape and replaced by a random Tensor in [0, scale);
    # anything else (e.g. an already-built Tensor) passes through unchanged.
    if isinstance(shp, list):
        if not shp:
            return Tensor((np.random.rand() * scale).astype(dtype))
        return Tensor((np.random.rand(*shp) * scale).astype(dtype))
    return shp


def gen_inputs(input_shapes, config):
    add_fack_input = config.get('add_fack_input', False)
    if not input_shapes and add_fack_input:
        return [Tensor(np.array([1.0]).astype(config.get('fack_input_type', np.float32)))]
    return [convert(shp) for shp in input_shapes]


def gen_backward_inputs(input_shapes, output_shapes, config):
    add_fack_input = config.get('add_fack_input', False)
    if not input_shapes and add_fack_input:
        inputs = [Tensor(np.array([1.0]))]
    else:
        inputs = [convert(shp) for shp in input_shapes]
    sens_shape = output_shapes[0]
    sens = convert(sens_shape)
    return inputs + [sens]


def append_sens_to_inputs(output_shapes, inputs):
    sens = Tensor(np.random.normal(0, 1, output_shapes).astype(np.float32))
    return inputs + [sens]

def gen_net(shapes, config, get_first=False):
    """
    gen_net function
    """
    add_fack_input = config.get('add_fack_input', False)
    op = config['op']
    if 'const' not in config:
        const_input = []
    else:
        const_input = config['const']
    const_first = False
    if 'const_first' in config:
        const_first = config['const_first']

    net = InputOpNet(op, get_first, *const_input)
    if const_first:
        fn_name = 'constructc%d_%d' % (len(const_input), len(shapes))
    else:
        fn_name = 'construct%d_c%d' % (len(shapes), len(const_input))
    if add_fack_input:
        fn_name += '_fack'
    f = getattr(net, fn_name)
    setattr(net, "construct", f)
    return net


def gen_backward_net(construct_net, input_num):
    net = InputBackward(construct_net)
    f = getattr(net, 'construct%d' % input_num)
    setattr(net, "construct", f)
    return net


def batch_tuple_tensor(data, batch_size):
    ret = [Tensor(np.tile(d.asnumpy(), (batch_size, 1))) for d in data]
    return tuple(ret)

class OutPutWrap(nn.Cell):
    """
    OutPutWrap definition
    """

    def __init__(self, network, num_output, output_is_tuple):
        super(OutPutWrap, self).__init__()
        self.network = network
        self.num_output = num_output
        self.one = Tensor(np.array([1]))
        self.dtype = P.DType()
        self.cast = P.Cast()
        self.output_is_tuple = output_is_tuple

    def construct(self, *inputs):
        pass

    def construct1(self, x1):
        ret = F.make_tuple()
        predict = self.network(x1)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct2(self, x1, x2):
        ret = F.make_tuple()
        predict = self.network(x1, x2)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct3(self, x1, x2, x3):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct4(self, x1, x2, x3, x4):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct5(self, x1, x2, x3, x4, x5):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4, x5)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct6(self, x1, x2, x3, x4, x5, x6):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4, x5, x6)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret


def get_output_wrap(network, num_input, num_output, output_is_tuple=0):
    net = OutPutWrap(network, num_output, output_is_tuple)
    f = getattr(net, 'construct%d' % num_input)
    setattr(net, "construct", f)
    return net
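
A minimal usage sketch of how these helpers fit together for a simple two-input op; this is not part of ops_common.py, and the choice of P.TensorAdd, the shapes, and the function name are illustrative assumptions:

def example_add_forward_and_backward():
    """Illustrative sketch only: compile forward and backward graphs for an elementwise add."""
    config = {'op': P.TensorAdd()}  # op under test; P.TensorAdd is assumed to be available here
    shapes = [[2, 3], [2, 3]]       # two tensor inputs of shape (2, 3)
    inputs = gen_inputs(shapes, config)           # random Tensors built via convert()
    net = gen_net(shapes, config)                 # binds InputOpNet.construct2_c0 as `construct`
    build_construct_graph(net, *inputs, execute=False)         # forward graph, compile only
    build_backward_graph(net, [2, 3], inputs, execute=False)   # appends a sens Tensor and compiles the grad net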