You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops_common.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test ops """
  16. import numpy as np
  17. import mindspore.nn as nn
  18. import mindspore.ops.composite as C
  19. import mindspore.ops.functional as F
  20. import mindspore.ops.operations as P
  21. from mindspore import Tensor
  22. from mindspore.common.api import _executor
class InputBackward(nn.Cell):
    """
    Wrap a network so that calling this cell returns gradients of the
    network with respect to its inputs, via C.grad_all_with_sens.

    The placeholder `construct` is replaced per instance (see
    gen_backward_net) by one of construct1..construct7, chosen by the number
    of network inputs. In every variant the final positional argument `sens`
    is passed to the grad function after the network inputs (presumably the
    output sensitivity used as the initial gradient -- see the MindSpore
    grad_all_with_sens docs).

    Args:
        network (nn.Cell): Forward network to differentiate; set_train() is
            called on it here.
        c1: Stored on the cell; not read by any method of this class.
        c2: Stored on the cell; not read by any method of this class.
    """
    def __init__(self, network, c1=None, c2=None):
        super(InputBackward, self).__init__()
        self.network = network
        self.network.set_train()
        self.grad = C.grad_all_with_sens
        self.c1 = c1
        self.c2 = c2

    def construct(self, *inputs):
        # Placeholder; replaced per instance by one of the constructN below.
        pass

    # constructN: N network inputs followed by the sens tensor.
    def construct1(self, x1, sens):
        return self.grad(self.network)(x1, sens)

    def construct2(self, x1, x2, sens):
        return self.grad(self.network)(x1, x2, sens)

    def construct3(self, x1, x2, x3, sens):
        return self.grad(self.network)(x1, x2, x3, sens)

    def construct4(self, x1, x2, x3, x4, sens):
        return self.grad(self.network)(x1, x2, x3, x4, sens)

    def construct5(self, x1, x2, x3, x4, x5, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, sens)

    def construct6(self, x1, x2, x3, x4, x5, x6, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, x6, sens)

    def construct7(self, x1, x2, x3, x4, x5, x6, x7, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, x6, x7, sens)
class InputOpNet(nn.Cell):
    """
    Wrap a single operator so it can be compiled and run as a Cell.

    The placeholder `construct` is replaced per instance (see gen_net) by one
    of the variants below. Naming scheme:
      * constructN_cM      -- op(x1, ..., xN, c1, ..., cM)
      * constructcM_N      -- constants first; only constructc1_1, which
                              calls op(c1, x1), exists
      * construct0_cM_fack -- op(c1, ..., cM) + data, where `data` is the
                              single runtime input
    Every variant returns x[0] instead of the whole result when `get_first`
    is set.

    Args:
        op: The operator (callable) under test.
        get_first (bool): If True, return only element 0 of the op result.
        c1, c2, c3, c4: Optional constant arguments stored on the cell and
            passed to `op` by the cM variants.
    """
    def __init__(self, op, get_first=False,
                 c1=None, c2=None, c3=None, c4=None):
        super(InputOpNet, self).__init__()
        self.op = op
        self.get_first = get_first
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.c4 = c4

    def construct(self, *inputs):
        # Placeholder; gen_net replaces it with one of the variants below.
        pass

    # Selected by gen_net when config['add_fack_input'] is set.
    # NOTE(review): presumably `data` is a dummy input so the graph has at
    # least one input even when the op takes only constants -- confirm.
    def construct0_c0_fack(self, data):
        x = self.op() + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c1_fack(self, data):
        x = self.op(self.c1) + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c2_fack(self, data):
        x = self.op(self.c1, self.c2) + data
        if self.get_first:
            x = x[0]
        return x

    # No runtime inputs: the op is called with constants only.
    def construct0_c0(self):
        x = self.op()
        if self.get_first:
            x = x[0]
        return x

    def construct0_c1(self):
        x = self.op(self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct0_c2(self):
        x = self.op(self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    # One runtime input followed by 0..4 constants.
    def construct1_c0(self, x1):
        x = self.op(x1)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c1(self, x1):
        x = self.op(x1, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c2(self, x1):
        x = self.op(x1, self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c3(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c4(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3, self.c4)
        if self.get_first:
            x = x[0]
        return x

    # Constant placed before the runtime input (gen_net with
    # config['const_first'] set).
    def constructc1_1(self, x1):
        x = self.op(self.c1, x1)
        if self.get_first:
            x = x[0]
        return x

    # Two or more runtime inputs.
    # NOTE(review): there is no construct2_c2 variant, so gen_net with two
    # inputs and two constants would fail at getattr.
    def construct2_c0(self, x1, x2):
        x = self.op(x1, x2)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c1(self, x1, x2):
        x = self.op(x1, x2, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c3(self, x1, x2):
        x = self.op(x1, x2, self.c1, self.c2, self.c3)
        if self.get_first:
            x = x[0]
        return x

    def construct3_c0(self, x1, x2, x3):
        x = self.op(x1, x2, x3)
        if self.get_first:
            x = x[0]
        return x

    def construct3_c1(self, x1, x2, x3):
        x = self.op(x1, x2, x3, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct4_c0(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4)
        if self.get_first:
            x = x[0]
        return x

    def construct4_c1(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct5_c0(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5)
        if self.get_first:
            x = x[0]
        return x

    def construct6_c0(self, x1, x2, x3, x4, x5, x6):
        x = self.op(x1, x2, x3, x4, x5, x6)
        if self.get_first:
            x = x[0]
        return x

    def construct5_c1(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5, self.c1)
        if self.get_first:
            x = x[0]
        return x
class NetOutputAsLoss(nn.Cell):
    """
    Wrap a network and return a single element of its output.

    The placeholder `construct` is replaced per instance (see get_loss_fun)
    by constructN, where N (1..5) is the number of network inputs. Each
    variant calls the network and returns `result[output_index]`.

    Args:
        network (nn.Cell): Network whose output is indexed.
        output_index: Index applied to the network result.
    """
    def __init__(self, network, output_index):
        super(NetOutputAsLoss, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, *inputs):
        # Placeholder; get_loss_fun replaces it with one of constructN below.
        pass

    def construct1(self, x1):
        predict = self.network(x1)[self.output_index]
        return predict

    def construct2(self, x1, x2):
        predict = self.network(x1, x2)[self.output_index]
        return predict

    def construct3(self, x1, x2, x3):
        predict = self.network(x1, x2, x3)[self.output_index]
        return predict

    def construct4(self, x1, x2, x3, x4):
        predict = self.network(x1, x2, x3, x4)[self.output_index]
        return predict

    def construct5(self, x1, x2, x3, x4, x5):
        predict = self.network(x1, x2, x3, x4, x5)[self.output_index]
        return predict
  194. def get_loss_fun(construct_net, num_input, output_index):
  195. net = NetOutputAsLoss(construct_net, output_index)
  196. f = getattr(net, 'construct%d' % num_input)
  197. setattr(net, "construct", f)
  198. return net
  199. def build_construct_graph(net, *inputs, execute=True):
  200. net.set_train()
  201. _executor.compile(net, *inputs)
  202. if execute:
  203. _executor(net, inputs)
  204. def build_backward_graph(net, output_shapes, inputs, execute=True):
  205. inputs = append_sens_to_inputs(output_shapes, inputs)
  206. net = gen_backward_net(net, len(inputs) - 1)
  207. net.set_train()
  208. _executor.compile(net, inputs)
  209. if execute:
  210. _executor(net, inputs)
  211. def convert(shp, dtype=np.float32, scale=6):
  212. if isinstance(shp, list):
  213. if not shp:
  214. return Tensor((np.random.rand() * scale).astype(dtype))
  215. return Tensor((np.random.rand(*shp) * scale).astype(dtype))
  216. return shp
  217. def gen_inputs(input_shapes, config):
  218. add_fack_input = config.get('add_fack_input', False)
  219. if not input_shapes and add_fack_input:
  220. return [Tensor(np.array([1.0]).astype(config.get('fack_input_type', np.float32)))]
  221. return [convert(shp) for shp in input_shapes]
  222. def gen_backward_inputs(input_shapes, output_shapes, config):
  223. add_fack_input = config.get('add_fack_input', False)
  224. if not input_shapes and add_fack_input:
  225. inputs = [Tensor(np.array([1.0]))]
  226. else:
  227. inputs = [convert(shp) for shp in input_shapes]
  228. sens_shape = output_shapes[0]
  229. sens = convert(sens_shape)
  230. return inputs + [sens]
  231. def append_sens_to_inputs(output_shapes, inputs):
  232. inputs = inputs
  233. sens = Tensor(np.random.normal(0, 1, output_shapes).astype(np.float32))
  234. return inputs + [sens]
  235. def gen_net(shapes, config, get_first=False):
  236. """
  237. gen_net function
  238. """
  239. add_fack_input = config.get('add_fack_input', False)
  240. op = config['op']
  241. if 'const' not in config:
  242. const_input = []
  243. else:
  244. const_input = config['const']
  245. const_first = False
  246. if 'const_first' in config:
  247. const_first = config['const_first']
  248. net = InputOpNet(op, get_first, *const_input)
  249. if const_first:
  250. fn_name = 'constructc%d_%d' % (len(const_input), len(shapes))
  251. else:
  252. fn_name = 'construct%d_c%d' % (len(shapes), len(const_input))
  253. if add_fack_input:
  254. fn_name += '_fack'
  255. f = getattr(net, fn_name)
  256. setattr(net, "construct", f)
  257. return net
  258. def gen_backward_net(construct_net, input_num):
  259. net = InputBackward(construct_net)
  260. f = getattr(net, 'construct%d' % input_num)
  261. setattr(net, "construct", f)
  262. return net
  263. def batch_tuple_tensor(data, batch_size):
  264. ret = [Tensor(np.tile(d.asnumpy(), (batch_size, 1))) for d in data]
  265. return tuple(ret)
class OutPutWrap(nn.Cell):
    """
    Wrap a network and re-emit its outputs, each multiplied by a constant one.

    The constant `self.one` (Tensor(np.array([1]))) is cast to each output's
    dtype before the multiply. NOTE(review): presumably the multiply exists so
    every output passes through an op instead of being returned directly --
    confirm intent.

    The placeholder `construct` is replaced per instance by constructN
    (see get_output_wrap), where N (1..6) is the number of network inputs.

    Args:
        network (nn.Cell): Network whose outputs are wrapped.
        num_output (int): Number of outputs taken from the network result.
        output_is_tuple: When 0 and num_output == 1, the network result is
            used directly (times one); otherwise elements 0..num_output-1 of
            the result are each multiplied by one and returned as a tuple.
    """
    def __init__(self, network, num_output, output_is_tuple):
        super(OutPutWrap, self).__init__()
        self.network = network
        self.num_output = num_output
        self.one = Tensor(np.array([1]))
        self.dtype = P.DType()
        self.cast = P.Cast()
        self.output_is_tuple = output_is_tuple

    def construct(self, *inputs):
        # Placeholder; get_output_wrap replaces it with one of constructN below.
        pass

    # construct1..construct6 share one body and differ only in arity.
    def construct1(self, x1):
        ret = F.make_tuple()
        predict = self.network(x1)
        # Single non-tuple output: return it directly (times one).
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        # Otherwise build a tuple of the first num_output elements.
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct2(self, x1, x2):
        ret = F.make_tuple()
        predict = self.network(x1, x2)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct3(self, x1, x2, x3):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct4(self, x1, x2, x3, x4):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct5(self, x1, x2, x3, x4, x5):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4, x5)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct6(self, x1, x2, x3, x4, x5, x6):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4, x5, x6)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret
  328. def get_output_wrap(network, num_input, num_output, output_is_tuple=0):
  329. net = OutPutWrap(network, num_output, output_is_tuple)
  330. f = getattr(net, 'construct%d' % num_input)
  331. setattr(net, "construct", f)
  332. return net