
matrix_calculation_ops.h

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_OP_MATRIX_CALCULATION_OPS_H
#define GE_OP_MATRIX_CALCULATION_OPS_H
#include "graph/operator_reg.h"
namespace ge {
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Three inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [K, N] to [N, K].
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
*/
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)
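/*
 * A minimal usage sketch. It assumes REG_OP generates a ge::op::MatMul class
 * with chainable set_input_* / set_attr_* setters, as graph/operator_reg.h
 * does for the ops in this file; "a" and "b" are hypothetical upstream ops.
 *
 *   ge::op::MatMul mm("mm");
 *   mm.set_input_x1(a)                // a: [M, K]
 *     .set_input_x2(b)                // b: [N, K]
 *     .set_attr_transpose_x2(true);   // y = a * b^T, shape [M, N]
 */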
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Four inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, int8. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, int8. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
* @li offset_w: An optional Tensor of type int8, reserved for quantization.
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [K, N] to [N, K].
*@li offset_x: An optional int, reserved for quantization. Defaults to "0".
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
*/
REG_OP(MatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2)
/**
*@brief Performs matrix-to-matrix multiplication, producing c=alpha[0]*a*b+beta[0]*c.
*@par Inputs:
*Five inputs, including:
*@li a: A matrix Tensor. 4D. Must be one of the following types:\n float16, int8. Has format [FRACTAL_NZ].
*@li b: A matrix Tensor. 4D. Must be one of the following types:\n float16, int8. When type is int8, has format [FRACTAL_Z], \n otherwise has format [FRACTAL_NZ].
*@li c: A matrix Tensor. 2D or higher. Must be one of the following types: \n float16, int32, float32. When type is int32, has format [ND], \n otherwise has format [FRACTAL_NZ].
*@li alpha: A 1D Tensor. The shape of alpha is [1].\n Must be one of the following types: float16, int32, float32. Has format [ND].
*@li beta: A 1D Tensor. The shape of beta is [1].\n Must be one of the following types: float16, int32, float32. Has format [ND].
*@par Attributes:
*Two attributes, including:
*@li transpose_a: Optional. A bool.\n If True, changes the shape of "a" from [M, K] to [K, M].\n Reserved parameter, not used for now.
*@li transpose_b: Optional. A bool.\n If True, changes the shape of "b" from [K, N] to [N, K].\n Reserved parameter, not used for now.
*@par Outputs:
*out: The result matrix Tensor. 4D. Must be one of the following types:\n float16, float32, int32. Has format [FRACTAL_NZ].
*/
REG_OP(Gemm)
    .INPUT(a, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(b, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(out, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(Gemm)
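/*
 * A worked example of the GEMM formula (values are illustrative only):
 * with a = [[1, 2], [3, 4]], b = [[1, 0], [0, 1]], c = [[1, 1], [1, 1]],
 * alpha = [2] and beta = [3], the op computes
 *   out = 2 * (a x b) + 3 * c = [[2+3, 4+3], [6+3, 8+3]] = [[5, 7], [9, 11]].
 */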
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Two inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
*@par Attributes:
*@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
*@li adj_x2: A bool. If True, changes the shape of "x2" from [B, K, N] to [B, N, K].
*@par Outputs:
*y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ]. Has the same shape length as "x1" and "x2".
*/
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)
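/*
 * Shape sketch: with x1 of shape [B, M, K] and x2 of shape [B, K, N] (both
 * adj flags false), y has shape [B, M, N]; the matrix product is applied
 * independently to each of the B batch slices. Setting adj_x2 = true lets
 * x2 be supplied as [B, N, K] instead.
 */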
/**
*@brief Computes the mean of elements of "x" across the dimensions given in "indices" (legacy CCE version).
*/
REG_OP(MeanCCE)
    .INPUT(x, TensorType::ALL())
    .INPUT(indices, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(keep_dims, Bool, false)
    .ATTR(value1, ListInt, {})
    .ATTR(mode, Int, 3)            // 0: max pooling, 1: avg pooling
    .ATTR(pad_mode, Int, 0)
    .ATTR(global_pooling, Bool, true)
    .ATTR(window, ListInt, {1,1})  // kernel size
    .ATTR(pad, ListInt, {0,0,0,0}) // pad size
    .ATTR(stride, ListInt, {1,1})  // stride size
    .ATTR(ceil_mode, Int, 0)
    .ATTR(data_mode, Int, 1)
    .ATTR(nan_opt, Int, 0)
    .ATTR(fomart, Int, 0)          // note: attribute name is spelled "fomart" in existing graphs; kept for compatibility
    .OP_END_FACTORY_REG(MeanCCE)
/**
*@brief Computes the gradient of the mean operation (legacy CCE version).
*/
REG_OP(MeanGrad)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(mode, Int, 1)            // 0: max pooling, 1: avg pooling
    .ATTR(pad_mode, Int, 0)
    .ATTR(global_pooling, Bool, false)
    .ATTR(window, ListInt, {1,1})  // kernel size
    .ATTR(pad, ListInt, {0,0,0,0}) // pad size
    .ATTR(stride, ListInt, {1,1})  // stride size
    .ATTR(ceil_mode, Int, 0)
    .ATTR(data_mode, Int, 1)
    .ATTR(nan_opt, Int, 0)
    .ATTR(mean_grad_output_shape_value, ListInt, {1,1,1,1})
    .ATTR(mean_grad_output_shape_format, Int, 1) // must be NHWC
    .OP_END_FACTORY_REG(MeanGrad)
/**
*@brief Multiplies matrix "x1" by matrix "x2" with an optional bias "x3" (legacy CCE version).
*/
REG_OP(MatMulCCE)
    .INPUT(x1, TensorType({DT_FLOAT}))
    .INPUT(x2, TensorType({DT_FLOAT}))
    .OPTIONAL_INPUT(x3, TensorType({DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .ATTR(has_bias, Bool, false)
    .OP_END_FACTORY_REG(MatMulCCE)
/**
*@brief Computes half the L2 norm of a tensor without the sqrt.
*@par Inputs:
* x: A Tensor.
* TensorType::FloatingDataType().
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)
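/*
 * Worked example: for x = [1.0, 2.0, 3.0] the op computes
 *   y = sum(x * x) / 2 = (1 + 4 + 9) / 2 = 7.0,
 * i.e. half the squared L2 norm, without taking the square root.
 */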
/**
*@brief: Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)
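/*
 * Example: a diagonal input x = [1, 2, 3] yields
 *   y = [[1, 0, 0],
 *        [0, 2, 0],
 *        [0, 0, 3]];
 * for batched input, the expansion is applied to each innermost vector.
 */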
/**
*@brief: Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)
/**
*@brief: Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)
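/*
 * Example: x = [[1, 2],
 *               [3, 4]] yields y = [1, 4]; for batched input, the main
 * diagonal of each innermost matrix is extracted.
 */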
/**
*@brief: Returns the batched diagonal part of a batched tensor.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)
/**
*@brief: Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)
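/*
 * Example: x = [[1, 2],
 *               [3, 4]] with diagonal = [9, 8] yields
 *   y = [[9, 2],
 *        [3, 8]];
 * only the main diagonal of each innermost matrix is replaced.
 */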
/**
*@brief: Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)
/**
*@brief Applies sparse "updates" to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int8, uint8, bool
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int8, uint8, bool
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
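/*
 * Semantics sketch (values are illustrative): with var = [1, 2, 3, 4],
 * indices = [[0], [2]] and updates = [9, 8], each row of "indices" addresses
 * one element (or slice) of "var", so the result is var = [9, 2, 8, 4].
 */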
/**
*@brief Applies sparse "updates" to individual values or slices in a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*/
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterUpdate)
/**
*@brief Adds sparse "updates" to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32.
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)
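/*
 * Semantics sketch (values are illustrative): with var = [1, 1, 1],
 * indices = [0, 2, 2] and updates = [5, 1, 2], each update is accumulated
 * into var[indices[i]], giving var = [6, 1, 4]; duplicate indices add up.
 */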
/**
*@brief Divides a variable reference by sparse updates.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)
/**
*@brief Applies sparse addition to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)
/**
*@brief Applies sparse addition to individual values or slices in a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*/
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)
/**
*@brief Applies sparse subtraction to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)
/**
*@brief Applies sparse subtraction to individual values or slices in a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*/
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)
/**
*@brief Subtracts sparse updates from a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)
/**
*@brief: Returns the batched diagonal part of a batched tensor with "assist".
*@par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)
/**
*@brief: Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int64, double, complex64, complex128.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)
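/*
 * Example: for a rank-2 input x = [[1, 2],
 *                                  [3, 4]], y = [1, 4]. In general, for a
 * rank-2k input, y[i1, ..., ik] = x[i1, ..., ik, i1, ..., ik].
 */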
/**
*@brief Computes an inner product with a set of learned weights and (optionally) adds biases; also known as a "fully-connected" layer.
*@par Inputs:
* Four inputs, including:
*@li x: A Tensor of type float16 or int8.
*@li w: A weight matrix Tensor of type float16 or int8.
*@li b: An optional Tensor of type float16, int32, or float32.
*@li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li num_output: Required. The number of output neurons.
*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false".
*@li axis: Reserved.
*@li offset_x: Reserved.
*@par Outputs:
*y: The result tensor of type float16, int32, or float32.
*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnection)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnection)
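/*
 * Shape sketch (an assumption based on the common inner-product layout, not
 * a statement of this kernel's exact layout): with x flattened to [batch, K]
 * from "axis" onward and w holding num_output rows of K weights,
 *   y[n][o] = sum_k x[n][k] * w[o][k] + b[o],
 * so y has shape [batch, num_output].
 */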
/**
*@brief Computes the confusion matrix from predictions and labels.
*@par Inputs:
* Three inputs, including:
*@li labels: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li predictions: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li weights: An optional Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Attributes:
*@li num_classes: An integer for the shape of the output matrix. No default value.
*@li dtype: Data type of the confusion matrix. No default value.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "labels".
*@attention Constraints:
*@li "weights", "labels", and "predictions" are 1D tensors.
*@li The output has shape (num_classes, num_classes), where 1 <= num_classes <= 4096.
*@see Region()
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)
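/*
 * Worked example: with labels = [1, 2, 4], predictions = [2, 2, 4],
 * num_classes = 5 and unit weights, row "label" / column "prediction" is
 * incremented per pair, so y[1][2] = 1, y[2][2] = 1, y[4][4] = 1 and all
 * other entries of the 5 x 5 matrix are 0.
 */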
/**
*@brief Multiplies sparse updates into a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)
/**
*@brief Reduces sparse updates into a variable reference using the "min" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)
/**
*@brief Reduces sparse updates into a variable reference using the "max" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int32
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)
/**
*@brief Applies sparse updates to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: float16, float32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)
/**
*@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input`.
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r` tensor where `r >= 2`. \n
*@li k: \n
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li padding_value: The value to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*diagonal: The extracted diagonal(s).
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)
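/*
 * Offset sketch: for input = [[1, 2, 3],
 *                             [4, 5, 6],
 *                             [7, 8, 9]], k = 0 extracts [1, 5, 9], k = 1
 * extracts the superdiagonal [2, 6], and k = -1 extracts the subdiagonal
 * [4, 8]; a pair k = [-1, 0] returns both bands, padded with padding_value.
 */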
/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1`, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape`.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Five inputs, including:
*@li diagonal: Rank `r`, where `r >= 1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li num_rows:
*The number of rows of the output matrix. If it is not provided, the op assumes \n
*the output matrix is a square matrix and infers the matrix size from k and the \n
*innermost dimension of `diagonal`. \n
*@li num_cols:
*The number of columns of the output matrix. If it is not provided, the op \n
*assumes the output matrix is a square matrix and infers the matrix size from \n
*k and the innermost dimension of `diagonal`. \n
*@li padding_value: The number to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)
} // namespace ge
#endif // GE_OP_MATRIX_CALCULATION_OPS_H

The Graph Engine (GE) is a submodule of MindSpore. It is implemented in C++ and sits between the front-end module (ME) and the underlying hardware, bridging the two. GE takes the graph delivered by ME as input, performs a series of deep graph-optimization passes on it, and finally outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts: GE API and GE Core.
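Since GE is invoked automatically, most users never call it directly, but the GE API can also be driven explicitly. The following is a minimal sketch of that flow, assuming the ge::Session interface from ge/ge_api.h (GEInitialize, AddGraph, RunGraph) and a `graph` already built from operators such as those registered in this header:

    #include "ge/ge_api.h"

    std::map<std::string, std::string> options;
    ge::GEInitialize(options);             // bring up GE with default options
    ge::Session session(options);
    session.AddGraph(0, graph);            // register a ge::Graph under id 0
    std::vector<ge::Tensor> inputs, outputs;
    session.RunGraph(0, inputs, outputs);  // optimize, compile, and execute
    ge::GEFinalize();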