
matrix_calculation_ops.h 33 kB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_OP_MATRIX_CALCULATION_OPS_H
#define GE_OP_MATRIX_CALCULATION_OPS_H

#include "graph/operator_reg.h"

namespace ge {

/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Three inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M].
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)

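// Usage sketch (illustrative, not part of the original header): REG_OP above
// generates a ge::op::MatMul class; the set_input_*/set_attr_* accessor names
// below are assumptions based on the operator_reg.h macro conventions and may
// differ across GE versions.
//
//   ge::op::MatMul matmul("matmul_0");
//   matmul.set_input_x1(a)                // "a": a 2D [M, K] tensor
//         .set_input_x2(b)                // "b": a 2D [K, N] tensor
//         .set_attr_transpose_x1(false)
//         .set_attr_transpose_x2(false);  // result "y": [M, N]
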
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Four inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, int8. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, int8. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
* @li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M].
*@li offset_x: An int. Defaults to 0.
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2)

/**
*@brief Performs matrix-to-matrix multiplication, producing c = alpha[0] * a * b + beta[0] * c.
*@par Inputs:
*Five inputs, including:
*@li a: A matrix Tensor. 4D. Must be one of the following types:\n float16, int8. Has format [FRACTAL_NZ].
*@li b: A matrix Tensor. 4D. Must be one of the following types:\n float16, int8. When the type is int8, has format [FRACTAL_Z];\n otherwise has format [FRACTAL_NZ].
*@li c: A matrix Tensor. 2D or higher. Must be one of the following types:\n float16, int32, float32. When the type is int32, has format [ND];\n otherwise has format [FRACTAL_NZ].
*@li alpha: A 1D Tensor of shape [1].\n Must be one of the following types: float16, int32, float32. Has format [ND].
*@li beta: A 1D Tensor of shape [1].\n Must be one of the following types: float16, int32, float32. Has format [ND].
*@par Attributes:
*Two attributes, including:
*@li transpose_a: Optional. A bool.\n If True, changes the shape of "a" from [M, K] to [K, M].\n Reserved parameter; not used for now.
*@li transpose_b: Optional. A bool.\n If True, changes the shape of "b" from [M, K] to [K, M].\n Reserved parameter; not used for now.
*@par Outputs:
*out: The result matrix Tensor. 4D. Must be one of the following types:\n float16, float32, int32. Has format [FRACTAL_NZ].
*/
REG_OP(Gemm)
    .INPUT(a, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(b, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(out, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(Gemm)

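// Semantics note (standard BLAS GEMM, restated from the @brief above):
// out = alpha[0] * a * b + beta[0] * c. For example, alpha = [1.0] and
// beta = [0.0] reduce Gemm to a plain matrix multiply, while beta = [1.0]
// accumulates the product into "c".
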
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Two inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
*@par Attributes:
*@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
*@li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M].
*@par Outputs:
*y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ]. Has the same shape length as "x1" and "x2".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator BatchMatMul.
*/
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)

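// Shape note (standard batch-matmul semantics): with adj_x1 = adj_x2 = false,
// x1 of shape [B, M, K] and x2 of shape [B, K, N] produce y of shape
// [B, M, N]; the leading batch dimensions of "x1" and "x2" must match.
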
REG_OP(MeanCCE)
    .INPUT(x, TensorType::ALL())
    .INPUT(indices, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(keep_dims, Bool, false)
    .ATTR(value1, ListInt, {})
    .ATTR(mode, Int, 3)               // 0: max pooling; 1: avg pooling
    .ATTR(pad_mode, Int, 0)
    .ATTR(global_pooling, Bool, true) // TensorFlow has no such attribute; a default is set here
    .ATTR(window, ListInt, {1,1})     // kernel size
    .ATTR(pad, ListInt, {0,0,0,0})    // pad size
    .ATTR(stride, ListInt, {1,1})     // stride size
    .ATTR(ceil_mode, Int, 0)
    .ATTR(data_mode, Int, 1)
    .ATTR(nan_opt, Int, 0)
    .ATTR(fomart, Int, 0)             // "fomart" (sic): name kept as registered for compatibility
    .OP_END_FACTORY_REG(MeanCCE)

REG_OP(MeanGrad)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(mode, Int, 1)               // 0: max pooling; 1: avg pooling
    .ATTR(pad_mode, Int, 0)
    .ATTR(global_pooling, Bool, false)
    .ATTR(window, ListInt, {1,1})     // kernel size
    .ATTR(pad, ListInt, {0,0,0,0})    // pad size
    .ATTR(stride, ListInt, {1,1})     // stride size
    .ATTR(ceil_mode, Int, 0)
    .ATTR(data_mode, Int, 1)
    .ATTR(nan_opt, Int, 0)
    .ATTR(mean_grad_output_shape_value, ListInt, {1,1,1,1})
    .ATTR(mean_grad_output_shape_format, Int, 1) // must be NHWC
    .OP_END_FACTORY_REG(MeanGrad)

REG_OP(MatMulCCE)
    .INPUT(x1, TensorType({DT_FLOAT}))
    .INPUT(x2, TensorType({DT_FLOAT}))
    .OPTIONAL_INPUT(x3, TensorType({DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .ATTR(has_bias, Bool, false)
    .OP_END_FACTORY_REG(MatMulCCE)

/**
*@brief Computes half the L2 norm of a tensor without the sqrt.
*@par Inputs:
* x: A Tensor of a floating-point type (TensorType::FloatingDataType()).
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator L2Loss.
*/
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)

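// Computation note (per the TensorFlow L2Loss definition): y = sum(x ** 2) / 2,
// a scalar. For example, x = [3.0, 4.0] gives y = (9 + 16) / 2 = 12.5.
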
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*/
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)

/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*/
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)

/**
*@brief Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*/
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)

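// Worked example (standard MatrixDiagPart semantics): for
// x = [[1, 0], [0, 2]], y = [1, 2]. For a batched input of shape [..., M, N],
// y has shape [..., min(M, N)].
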
/**
*@brief Returns the batched diagonal part of a batched tensor.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*/
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)

/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@li diagonal: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)

/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)

/**
*@brief Applies sparse "updates" to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, uint32,
* uint64.
*@li indices: An ND Tensor.
*Must be one of the following types: int32, int64.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, uint32,
* uint64.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdUpdate.
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)

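// Worked example (standard scatter_nd_update semantics, illustrative values):
//   var     = [1, 2, 3, 4]
//   indices = [[1], [3]]
//   updates = [9, 10]
// Afterwards var = [1, 9, 3, 10]: each row of "indices" selects a slice of
// "var" that is overwritten with the matching row of "updates".
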
/**
*@brief Applies sparse "updates" to individual values or slices of a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterUpdate.
*/
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterUpdate)

/**
*@brief Adds sparse "updates" to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32 or int64.
*@li updates: A Tensor with format NCHW or NHWC.
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterAdd.
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)

/**
*@brief Divides a variable reference by sparse updates.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterDiv.
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)

/**
*@brief Applies sparse addition to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdAdd.
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)

/**
*@brief Applies sparse addition to individual values or slices of a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterAdd.
*/
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)

/**
*@brief Applies sparse subtraction to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32, int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdSub.
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)

/**
*@brief Applies sparse subtraction to individual values or slices of a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterSub.
*/
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)

/**
*@brief Subtracts sparse updates from a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32, int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterSub.
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)

/**
*@brief Returns the batched diagonal part of a batched tensor with "assist".
*@par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)

/**
*@brief Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, int32, int64, double, complex64, complex128.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)

/**
*@brief Also known as a "fully connected" layer; computes an inner product with a set of learned weights and (optionally) adds biases.
*@par Inputs:
* Four inputs, including:
*@li x: A Tensor of type float16 or int8.
*@li w: A weight matrix of type float16 or int8.
*@li b: An optional Tensor of type float16, int32, or float32.
*@li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li num_output: Required. Reserved.
*@li transpose: A bool, specifying whether to transpose the weight matrix, either "true" or "false". Defaults to "false".
*@li axis: Optional. An int, either 1 or 2. Defaults to 1.
*@li offset_x: Reserved.
*@par Outputs:
*y: The result tensor of type float16, int32, or float32.
*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct.
*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnection)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnection)

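// Shape note (assumed from the usual inner-product layer semantics): with
// axis = 1 and transpose = false, an input x of shape [N, C] and a weight
// matrix w with num_output rows produce y of shape [N, num_output],
// computed as y = x * w^T + b when the optional bias "b" is supplied.
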
/**
*@brief Computes the confusion matrix from predictions and labels.
*@par Inputs:
* Three inputs, including:
*@li labels: A Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8.
*@li predictions: A Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
*@li weights: An optional Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
*@par Attributes:
*@li num_classes: An integer for the shape of the output matrix.
* No default value.
*@li dtype: Data type of the confusion matrix. No default value.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "labels".
*@attention Constraints:
*@li "weights", "labels", and "predictions" are 1D tensors.
*@li The output has shape (num_classes, num_classes),
* where 1 <= num_classes <= 4096.
*@see Region()
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ConfusionMatrix.
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)

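// Worked example (standard confusion-matrix semantics, illustrative values):
// with num_classes = 3, labels = [0, 1, 2, 2], predictions = [0, 0, 2, 1],
// y is the 3 x 3 matrix
//   [[1, 0, 0],
//    [1, 0, 0],
//    [0, 1, 1]]
// where entry (i, j) counts samples with label i predicted as class j;
// if "weights" is supplied, each count of 1 is replaced by the sample weight.
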
/**
*@brief Multiplies sparse updates into a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMul.
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)

/**
*@brief Reduces sparse updates into a variable reference using
* the "min" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32
*@li indices: An ND Tensor.
*Must be one of the following types: int32
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMin.
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)

/**
*@brief Reduces sparse updates into a variable reference using the "max" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32
*@li indices: An NCHW, NHWC, or ND Tensor.
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor.
*Must be one of the following types: float16, float, int32
*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMax.
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)

/**
*@brief Applies sparse updates to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterUpdate.
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)

/**
*@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input`.
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r` tensor where `r >= 2`. \n
*@li k: \n
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li padding_value: The value to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*diagonal: The extracted diagonal(s).
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPartV2.
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)

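// Worked example (standard banded-diagonal extraction): for
//   input = [[1, 2, 3],
//            [4, 5, 6],
//            [7, 8, 9]],
// k = 0 yields [1, 5, 9] (the main diagonal) and k = 1 yields [2, 6] (the
// first superdiagonal); when k is a pair, the shorter diagonals in the band
// are padded with "padding_value".
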
/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1`, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape`.
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV2.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)

/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Five inputs, including:
*@li diagonal: Rank `r`, where `r >= 1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li num_rows:
*The number of rows of the output matrix. If it is not provided, the op assumes \n
*the output matrix is a square matrix and infers the matrix size from k and the \n
*innermost dimension of `diagonal`. \n
*@li num_cols:
*The number of columns of the output matrix. If it is not provided, the op \n
*assumes the output matrix is a square matrix and infers the matrix size from \n
*k and the innermost dimension of `diagonal`. \n
*@li padding_value: The number to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise.
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV2.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)

} // namespace ge

#endif // GE_OP_MATRIX_CALCULATION_OPS_H

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware and acts as a bridge between them: GE takes the graph delivered by ME as input, performs a series of deep graph-optimization passes, and outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts: GE API and GE Core.
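As a rough illustration of how the registrations in matrix_calculation_ops.h are consumed, the sketch below builds a tiny graph with the GE graph API and the MatMul operator declared above. This is a minimal sketch only: the ge::op::Data input operator, the include paths, the generated set_input_*/set_attr_* accessors, and the Graph::SetInputs/SetOutputs calls are assumed from the public Ascend GE IR conventions and may differ between versions.

#include "graph/graph.h"
#include "array_ops.h"                // assumed location of ge::op::Data
#include "matrix_calculation_ops.h"   // assumed include path for the header above

// Builds a graph computing y = a * b with the MatMul operator registered above.
ge::Graph BuildMatMulGraph() {
  // Two graph inputs feeding MatMul; Data is the standard GE input operator.
  auto a = ge::op::Data("a").set_attr_index(0);
  auto b = ge::op::Data("b").set_attr_index(1);

  auto matmul = ge::op::MatMul("matmul")
                    .set_input_x1(a)
                    .set_input_x2(b)
                    .set_attr_transpose_x1(false)
                    .set_attr_transpose_x2(false);

  ge::Graph graph("matmul_graph");
  graph.SetInputs({a, b}).SetOutputs({matmul});
  return graph;
}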