
matrix_calculation_ops.h 29 kB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_OP_MATRIX_CALCULATION_OPS_H
#define GE_OP_MATRIX_CALCULATION_OPS_H

#include "graph/operator_reg.h"

namespace ge {
/**
*@brief Multiplies matrix "x1" by matrix "x2", producing "x1 * x2".
*@par Inputs:
*Two inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [K, N] to [N, K].
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
*/
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)
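
/* A minimal graph-construction sketch (illustrative, not part of this header):
 * it assumes the set_input_* / set_attr_* accessors that the REG_OP macro
 * conventionally generates for each registered operator, and "data_x1" /
 * "data_x2" stand for previously created upstream operators.
 *
 *   ge::op::MatMul matmul("matmul_1");
 *   matmul.set_input_x1(data_x1)
 *         .set_input_x2(data_x2)
 *         .set_attr_transpose_x1(false)
 *         .set_attr_transpose_x2(false);
 */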
/**
*@brief Performs matrix-to-matrix multiplication, producing c = alpha[0] * a * b + beta[0] * c.
*@par Inputs:
*Five inputs, including:
*@li a: A matrix Tensor. 4D. Must be one of the following types:\n float16, int8. Has format [FRACTAL_NZ].
*@li b: A matrix Tensor. 4D. Must be one of the following types:\n float16, int8. When type is int8, has format [FRACTAL_Z], \n otherwise has format [FRACTAL_NZ].
*@li c: A matrix Tensor. 2D or higher. Must be one of the following types: \n float16, int32, float32. When type is int32, has format [ND], \n otherwise has format [FRACTAL_NZ].
*@li alpha: A 1D Tensor. The shape of alpha is [1].\n Must be one of the following types: float16, int32, float32. Has format [ND].
*@li beta: A 1D Tensor. The shape of beta is [1].\n Must be one of the following types: float16, int32, float32. Has format [ND].
*@par Attributes:
*Two attributes, including:
*@li transpose_a: Optional. A bool.\n If True, changes the shape of "a" from [M, K] to [K, M].\n Reserved parameter, not used for now.
*@li transpose_b: Optional. A bool.\n If True, changes the shape of "b" from [K, N] to [N, K].\n Reserved parameter, not used for now.
*@par Outputs:
*out: The result matrix Tensor. 4D. Must be one of the following types:\n float16, float32, int32. Has format [FRACTAL_NZ].
*/
REG_OP(Gemm)
    .INPUT(a, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(b, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(out, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(Gemm)
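
/* Reference semantics for the Gemm formula above on plain 2-D row-major
 * buffers. An illustrative sketch only: the real op works on FRACTAL_NZ /
 * FRACTAL_Z formatted tensors, and "GemmRef" is not part of this header's
 * API. */
inline void GemmRef(const float *a, const float *b, float *c,
                    int m, int k, int n, float alpha, float beta) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += a[i * k + p] * b[p * n + j];  // (a * b)[i][j]
      }
      c[i * n + j] = alpha * acc + beta * c[i * n + j];  // c = alpha*a*b + beta*c
    }
  }
}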
/**
*@brief Multiplies matrix "x1" by matrix "x2" in batches, producing "x1 * x2".
*@par Inputs:
*Two inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
*@par Attributes:
*@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
*@li adj_x2: A bool. If True, changes the shape of "x2" from [B, K, N] to [B, N, K].
*@par Outputs:
*y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ]. Has the same shape length as "x1" and "x2".
*/
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)
REG_OP(MeanCCE)
    .INPUT(x, TensorType::ALL())
    .INPUT(indices, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(keep_dims, Bool, false)
    .ATTR(value1, ListInt, {})
    .ATTR(mode, Int, 3)  // 0: max pooling, 1: avg pooling
    .ATTR(pad_mode, Int, 0)
    .ATTR(global_pooling, Bool, true)
    .ATTR(window, ListInt, {1, 1})  // kernel size
    .ATTR(pad, ListInt, {0, 0, 0, 0})  // pad size
    .ATTR(stride, ListInt, {1, 1})  // stride size
    .ATTR(ceil_mode, Int, 0)
    .ATTR(data_mode, Int, 1)
    .ATTR(nan_opt, Int, 0)
    .ATTR(format, Int, 0)
    .OP_END_FACTORY_REG(MeanCCE)

REG_OP(MeanGrad)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(mode, Int, 1)  // 0: max pooling, 1: avg pooling
    .ATTR(pad_mode, Int, 0)
    .ATTR(global_pooling, Bool, false)
    .ATTR(window, ListInt, {1, 1})  // kernel size
    .ATTR(pad, ListInt, {0, 0, 0, 0})  // pad size
    .ATTR(stride, ListInt, {1, 1})  // stride size
    .ATTR(ceil_mode, Int, 0)
    .ATTR(data_mode, Int, 1)
    .ATTR(nan_opt, Int, 0)
    .ATTR(mean_grad_output_shape_value, ListInt, {1, 1, 1, 1})
    .ATTR(mean_grad_output_shape_format, Int, 1)  // must be NHWC
    .OP_END_FACTORY_REG(MeanGrad)

REG_OP(MatMulCCE)
    .INPUT(x1, TensorType({DT_FLOAT}))
    .INPUT(x2, TensorType({DT_FLOAT}))
    .OPTIONAL_INPUT(x3, TensorType({DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .ATTR(has_bias, Bool, false)
    .OP_END_FACTORY_REG(MatMulCCE)
/**
*@brief Computes half the L2 norm of a tensor without the sqrt.
*@par Inputs:
* x: A Tensor.
* TensorType::FloatingDataType().
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)
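
/* Reference semantics for L2Loss (an illustrative sketch, not part of this
 * header's API): half the squared L2 norm of the input, without the sqrt. */
inline float L2LossRef(const float *x, int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc += x[i] * x[i];  // sum of squares
  }
  return 0.5f * acc;  // half the L2 norm, no sqrt
}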
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)
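
/* A sketch of MatrixDiag semantics for a single rank-1 input of length n
 * (illustrative only, not part of this header's API): the output is an n x n
 * matrix with "x" on the main diagonal and zeros elsewhere; batched inputs
 * apply this per batch element. */
inline void MatrixDiagRef(const float *x, int n, float *y /* n * n */) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      y[i * n + j] = (i == j) ? x[i] : 0.0f;
    }
  }
}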
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)

/**
*@brief Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)

/**
*@brief Returns the batched diagonal part of a batched tensor.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)

/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)

/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)
/**
*@brief Applies sparse "updates" to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int8, uint8, bool
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int8, uint8, bool
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
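
/* A sketch of ScatterNdUpdate semantics for the simplest case (illustrative
 * only, not part of this header's API): rank-1 "var", "indices" of shape
 * [N, 1], and rank-1 "updates"; each index row selects an element of "var"
 * to overwrite. Higher-rank index rows overwrite whole slices instead. */
inline void ScatterNdUpdateRef(float *var, const int *indices,
                               const float *updates, int num_updates) {
  for (int i = 0; i < num_updates; ++i) {
    var[indices[i]] = updates[i];  // in-place update of the referenced element
  }
}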
/**
*@brief Applies sparse "updates" to individual values or slices of a tensor, producing a new output tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*/
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterUpdate)
/**
*@brief Adds sparse "updates" to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32.
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)
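
/* A sketch of ScatterAdd semantics for rank-1 operands (illustrative only,
 * not part of this header's API): each update is accumulated into "var" at
 * the referenced index. The sibling ops ScatterSub / ScatterMul / ScatterDiv /
 * ScatterMin / ScatterMax differ only in the operator applied per element. */
inline void ScatterAddRef(float *var, const int *indices,
                          const float *updates, int num_updates) {
  for (int i = 0; i < num_updates; ++i) {
    var[indices[i]] += updates[i];  // "+=" becomes -=, *=, /=, min, max in siblings
  }
}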
/**
*@brief Divides a variable reference by sparse updates.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)
/**
*@brief Applies sparse addition to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)
/**
*@brief Applies sparse addition to individual values or slices of a tensor, producing a new output tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*/
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)
/**
*@brief Applies sparse subtraction to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)
/**
*@brief Applies sparse subtraction to individual values or slices of a tensor, producing a new output tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*/
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)
/**
*@brief Subtracts sparse "updates" from a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)
/**
*@brief Returns the diagonal part of a tensor with "assist".
*@par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)

/**
*@brief Returns the diagonal part of a tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int64, double, complex64, complex128.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)
/**
*@brief Also known as a "fully-connected" layer; computes an inner product with a set of learned weights, and (optionally) adds biases.
*@par Inputs:
* Four inputs, including:
*@li x: A Tensor of type float16 or int8.
*@li w: A weight matrix of type float16 or int8.
*@li b: An optional Tensor of type float16, int32, or float32.
*@li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li num_output: Reserved.
*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false".
*@li axis: Reserved.
*@li offset_x: Reserved.
*@par Outputs:
*y: The result tensor of type float16, int32, or float32.
*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnection)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnection)
/**
*@brief Computes the confusion matrix from predictions and labels.
*@par Inputs:
* Three inputs, including:
*@li labels: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li predictions: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li weights: An optional Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Attributes:
*@li num_classes: An integer for the shape of the output matrix. No default value.
*@li dtype: Data type of the confusion matrix. No default value.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "labels".
*@attention Constraints:
*@li "weights", "labels", and "predictions" are 1D tensors.
*@li The output has shape (num_classes, num_classes), where 1 <= num_classes <= 4096.
*@see Region()
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)
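
/* A sketch of ConfusionMatrix semantics on rank-1 integer-valued inputs
 * (illustrative only, not part of this header's API): each (label,
 * prediction) pair increments one cell of the num_classes x num_classes
 * matrix, by the matching weight when "weights" is supplied, else by 1. */
inline void ConfusionMatrixRef(const int *labels, const int *predictions,
                               const float *weights,  // may be nullptr
                               int n, int num_classes,
                               float *y) {  // y holds num_classes * num_classes cells
  for (int i = 0; i < num_classes * num_classes; ++i) {
    y[i] = 0.0f;  // start from an all-zero matrix
  }
  for (int i = 0; i < n; ++i) {
    y[labels[i] * num_classes + predictions[i]] +=
        (weights != nullptr) ? weights[i] : 1.0f;
  }
}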
/**
*@brief Multiplies sparse updates into a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)
/**
*@brief Reduces sparse updates into a variable reference using the "min" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)

/**
*@brief Reduces sparse updates into a variable reference using the "max" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)
/**
*@brief Applies sparse updates to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)
/**
*@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input`.
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r` tensor where `r >= 2`. \n
*@li k: \n
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li padding_value: The value to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*diagonal: The extracted diagonal(s).
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)
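
/* A sketch of single-diagonal extraction as MatrixDiagPartV2 performs it when
 * "k" is a single integer (illustrative only, not part of this header's API):
 * positive k reads a superdiagonal, negative k a subdiagonal; entry i comes
 * from input[i][i + k] when k >= 0 and input[i - k][i] when k < 0. The caller
 * supplies a diag_len that stays within the matrix bounds. */
inline void DiagPartOffsetRef(const float *input, int cols, int k,
                              float *diagonal, int diag_len) {
  for (int i = 0; i < diag_len; ++i) {
    int row = (k >= 0) ? i : i - k;  // shift down for subdiagonals
    int col = (k >= 0) ? i + k : i;  // shift right for superdiagonals
    diagonal[i] = input[row * cols + col];
  }
}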
/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1` tensor, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape`.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Five inputs, including:
*@li diagonal: Rank `r`, where `r >= 1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li num_rows: \n
*The number of rows of the output matrix. If it is not provided, the op assumes \n
*the output matrix is a square matrix and infers the matrix size from k and the \n
*innermost dimension of `diagonal`. \n
*@li num_cols: \n
*The number of columns of the output matrix. If it is not provided, the op \n
*assumes the output matrix is a square matrix and infers the matrix size from \n
*k and the innermost dimension of `diagonal`. \n
*@li padding_value: The number to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)
} // namespace ge
#endif // GE_OP_MATRIX_CALCULATION_OPS_H

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module ME and the underlying hardware, acting as the bridge between the two. GE takes the graph delivered by ME as input, applies a series of deep graph-optimization passes, and finally outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts: GE API and GE Core.