
matrix_calculation_ops.h

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file matrix_calculation_ops.h
 * \brief
 */
#ifndef GE_OP_MATRIX_CALCULATION_OPS_H
#define GE_OP_MATRIX_CALCULATION_OPS_H

#include "graph/operator_reg.h"

namespace ge {
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Three inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [K, N] to [N, K].
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)
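/* Usage sketch (illustrative, not part of the original header): the REG_OP
 * macro generates set_input_<name> and set_attr_<name> accessors on the
 * ge::op::MatMul class, so a node can be wired into a graph roughly as
 * below. The Data placeholders "x1_data" and "x2_data" are hypothetical.
 *
 *   auto x1_data = ge::op::Data("x1_data");
 *   auto x2_data = ge::op::Data("x2_data");
 *   auto mm = ge::op::MatMul("mm")
 *                 .set_input_x1(x1_data)
 *                 .set_input_x2(x2_data)
 *                 .set_attr_transpose_x1(false)
 *                 .set_attr_transpose_x2(true);  // treat "x2" as transposed
 */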
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Four inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, int8. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, int8. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
* @li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [K, N] to [N, K].
*@li offset_x: An optional int. Defaults to 0.
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2)
/**
*@brief Performs matrix-to-matrix multiplication, producing c = alpha[0] * a * b + beta[0] * c.
*@par Inputs:
*Five inputs, including:
*@li a: A matrix Tensor. Must be one of the following types: float16, int8.
* Has format [ND, FRACTAL_NZ]. 2D(ND) or 4D(FRACTAL_NZ).
*@li b: A matrix Tensor. Must be one of the following types: float16, int8.
* Has format [ND, FRACTAL_NZ, FRACTAL_Z]. 2D(ND) or 4D(FRACTAL_NZ, FRACTAL_Z).
*@li c: A matrix Tensor. Must be one of the following types: float16, int32,
* float32. Has format [ND, FRACTAL_NZ]. 2D(ND) or 4D(FRACTAL_NZ).
*@li alpha: A 1D Tensor. The shape of alpha is [1]. Must be one of the following
* types: float16, int32, float32. Has format [ND].
*@li beta: A 1D Tensor. The shape of beta is [1]. Must be one of the following
* types: float16, int32, float32. Has format [ND].
* The formats of a, b and c are restricted as follows:\n
* When the type of a is int8 and the type of c is int32, the formats of a, b and c
* must all be ND, or a is FRACTAL_NZ, b is FRACTAL_Z and c is ND.\n
* When the type of a is int8 and the type of c is float32, the formats of a, b and c
* must all be ND, or a is FRACTAL_NZ, b is FRACTAL_Z and c is FRACTAL_NZ.\n
* When the type of a is float16 and the type of c is float16, the formats of a, b
* and c must all be ND or FRACTAL_NZ.\n
* When the type of a is float16 and the type of c is float32, the formats of a, b
* and c must all be ND or FRACTAL_NZ.
*@par Attributes:
*Two attributes, including:
*@li transpose_a: Optional. A bool. If True, changes the shape of "a" from
* [M, K] to [K, M].
*@li transpose_b: Optional. A bool. If True, changes the shape of "b" from
* [K, N] to [N, K].
*@par Outputs:
*y: The result matrix Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, FRACTAL_NZ]; the format must be the same as
* that of a. 2D(ND) or 4D(FRACTAL_NZ).
*/
REG_OP(GEMM)
    .INPUT(a, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(b, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(GEMM)
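/* Worked example of the GEMM formula (illustrative values): with
 * transpose_a and transpose_b left "false", the op computes
 *   y = alpha[0] * a * b + beta[0] * c.
 * For a = [[1, 2], [3, 4]], b = identity(2), c = [[1, 1], [1, 1]],
 * alpha = [2] and beta = [1], the result is 2 * a + c = [[3, 5], [7, 9]].
 */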
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b".
*@par Inputs:
*Two inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
*@par Attributes:
*@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
*@li adj_x2: A bool. If True, changes the shape of "x2" from [B, K, N] to [B, N, K].
*@par Outputs:
*y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ]. Has the same shape length as "x1" and "x2".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator BatchMatmul.
*/
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)
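/* Shape sketch: for x1 of shape [B, M, K] and x2 of shape [B, K, N], each
 * batch slice is multiplied independently, y[b] = x1[b] * x2[b], giving y
 * of shape [B, M, N]. Setting adj_x1/adj_x2 transposes the two innermost
 * dimensions of the corresponding input before the per-batch product.
 */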
/**
*@brief Computes half the L2 norm of a tensor without the sqrt.
*@par Inputs:
* x: A Tensor of type TensorType::FloatingDataType().
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator L2Loss.
*/
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)
/**
*@brief: Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*/
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)
/**
*@brief: Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*/
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)
/**
*@brief: Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*/
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)
/**
*@brief: Returns the batched diagonal part of a batched tensor.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*/
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)
/**
*@brief: Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@li diagonal: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)
/**
*@brief: Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)
/**
*@brief Applies sparse "updates" to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, half, uint32,
* uint64.
*@li indices: An ND Tensor.
*Must be one of the following types: int32, int64.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, half, uint32,
* uint64.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdUpdate.
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
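/* Worked example of the ND indexing (following TensorFlow ScatterNdUpdate
 * semantics): with var of shape [4, 3], indices = [[0], [2]] and updates
 * of shape [2, 3], rows var[0] and var[2] are overwritten with updates[0]
 * and updates[1]; rows 1 and 3 are left untouched. Each row of "indices"
 * addresses a slice of var of shape var.shape[indices.shape[-1]:].
 */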
/**
*@brief Applies sparse "updates" to individual values or slices in a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8.
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32.
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterUpdate.
*/
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterUpdate)
/**
*@brief Adds sparse "updates" to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int32, int8, uint8.
*@li indices: An ND Tensor of type int32 or int64.
*@li updates: A Tensor. Format: NCHW or NHWC.
*Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterAdd.
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)
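/* Worked example: the op accumulates along the first dimension,
 * var[indices[i], ...] += updates[i, ...]. With var = [[1, 1], [2, 2]],
 * indices = [1, 1] and updates = [[3, 3], [4, 4]], both update rows land
 * on row 1, so var becomes [[1, 1], [9, 9]].
 */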
/**
*@brief Divides a variable reference by sparse updates.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@li indices: An ND Tensor.
*Must be one of the following types: int32.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterDiv.
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)
/**
*@brief Applies sparse addition to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@li indices: An ND Tensor.
*Must be one of the following types: int32.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdAdd.
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)
/**
*@brief Applies sparse addition to individual values or slices in a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8.
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32.
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterAdd.
*/
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)
/**
*@brief Applies sparse subtraction to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@li indices: An ND Tensor.
*Must be one of the following types: int32, int64.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdSub.
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)
/**
*@brief Applies sparse subtraction to individual values or slices in a tensor.
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8.
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32.
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterSub.
*/
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)
/**
*@brief Subtracts sparse updates from a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@li indices: An ND Tensor.
*Must be one of the following types: int32, int64.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterSub.
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)
/**
*@brief: Returns the batched diagonal part of a batched tensor with "assist".
*@par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)
/**
*@brief: Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, int32, int64, double, complex64, complex128.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)
/**
*@brief Also known as a "fully-connected" layer; computes an inner product with a set of learned weights, and (optionally) adds biases.
*@par Inputs:
* Four inputs, including:
*@li x: A Tensor of type float16 or int8.
*@li w: A weight matrix of type float16 or int8.
*@li b: An optional Tensor of type float16, int32, or float32.
*@li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li num_output: Reserved.
*@li transpose: A bool, specifying whether to transpose the weight matrix, either "true" or "false". Defaults to "false".
*@li axis: Optional. An int, 1 or 2, specifying the dimension at which "K" starts. Defaults to 1.
* "K" is the product of the dimensions from the first or the second dimension onward, respectively.
*@li offset_x: Reserved.
*@par Outputs:
*y: The result tensor of type float16, int32, or float32.
*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct.
*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnection)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnection)
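/* Shape sketch for the "axis" attribute (illustrative): with x of shape
 * [N, C, H, W] and axis = 1, the reduced dimension is K = C * H * W, the
 * weight is viewed as a [num_output, K] matrix (transpose = "false"), and
 * y has shape [N, num_output]. With axis = 2, K = H * W and the inner
 * product is applied per [N, C] slice instead.
 */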
/**
*@brief Also known as a "fully-connected-compress" layer; computes an inner product with a set of learned weights, and (optionally) adds biases.
*@par Inputs:
* Five inputs, including:
*@li x: A Tensor of type uint8 or int8.
*@li w: A weight matrix of type int8.
*@li compress_index: A compress index matrix of type int8.
*@li b: An optional Tensor of type int32.
*@li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li num_output: Reserved.
*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false".
*@li axis: Reserved.
*@li offset_x: Reserved.
*@par Outputs:
*y: The result tensor of type int32.
*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct.
*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnectionCompress)
    .INPUT(x, TensorType({DT_UINT8, DT_INT8}))
    .INPUT(w, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnectionCompress)
/**
*@brief Computes the confusion matrix from predictions and labels.
*@par Inputs:
* Three inputs, including:
*@li labels: A Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8.
*@li predictions: A Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
*@li weights: An optional Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
*@par Attributes:
*@li num_classes: An integer for the shape of the output matrix.
* No default value.
*@li dtype: Data type of the confusion matrix. No default value.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "labels".
*@attention Constraints:
*@li "weights", "labels", and "predictions" are 1D tensors.
*@li The output has shape (num_classes, num_classes),
* where 1 <= num_classes <= 4096.
*@see Region()
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ConfusionMatrix.
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)
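/* Worked example (following TensorFlow ConfusionMatrix semantics): with
 * labels = [1, 2, 4], predictions = [2, 2, 4] and num_classes = 5, y is a
 * 5 x 5 matrix that is all zeros except y[1][2] = 1, y[2][2] = 1 and
 * y[4][4] = 1. When "weights" is supplied, each count is replaced by the
 * corresponding weight instead of 1.
 */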
/**
*@brief Multiplies sparse updates into a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@li indices: An ND Tensor.
*Must be one of the following types: int32.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMul.
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)
/**
*@brief Reduces sparse updates into a variable reference using
* the "min" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32.
*@li indices: An ND Tensor.
*Must be one of the following types: int32.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMin.
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)
/**
*@brief Reduces sparse updates into a variable reference using the "max" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32.
*@li indices: An NCHW, NHWC, or ND Tensor.
*Must be one of the following types: int32.
*@li updates: An NCHW, NHWC, or ND Tensor.
*Must be one of the following types: float16, float, int32.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMax.
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)
/**
*@brief Applies sparse updates to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int8, uint8.
*@li indices: An ND Tensor.
*Must be one of the following types: int32.
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int8, uint8.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterUpdate.
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)
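/* Worked example: unlike ScatterAdd, the op overwrites,
 * var[indices[i], ...] = updates[i, ...]. With var = [[1, 1], [2, 2]],
 * indices = [0] and updates = [[5, 5]], var becomes [[5, 5], [2, 2]].
 * When an index is repeated, the result of the competing writes is
 * unspecified.
 */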
/**
*@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input`.
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r` tensor where `r >= 2`. \n
*@li k: \n
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li padding_value: The value to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*diagonal: The extracted diagonal(s).
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPartV2.
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)
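/* Worked example of the band selection (following TensorFlow
 * MatrixDiagPartV2 semantics): for a 3 x 4 input and k = [-1, 1], the op
 * extracts the superdiagonal (k = 1), the main diagonal (k = 0), and the
 * subdiagonal (k = -1); the subdiagonal has only 2 elements and is padded
 * to length 3 with padding_value, so the result is a [3, 3] tensor. A
 * scalar k = 0 returns just the main diagonal.
 */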
/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1` tensor, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape`.
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV2.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Five inputs, including:
*@li diagonal: Rank `r`, where `r >= 1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li num_rows:
*The number of rows of the output matrix. If it is not provided, the op assumes \n
*the output matrix is a square matrix and infers the matrix size from k and the \n
*innermost dimension of `diagonal`. \n
*@li num_cols:
*The number of columns of the output matrix. If it is not provided, the op \n
*assumes the output matrix is a square matrix and infers the matrix size from \n
*k and the innermost dimension of `diagonal`. \n
*@li padding_value: The number to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise.
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV2.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)
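/* Worked example: with diagonal = [1, 2, 3], scalar k = 0,
 * num_rows = num_cols = 3 and padding_value = 0, the output is
 *   [[1, 0, 0],
 *    [0, 2, 0],
 *    [0, 0, 3]].
 * With k = 1 and num_rows = num_cols = 4, the same values land on the
 * superdiagonal of a 4 x 4 matrix instead.
 */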
}  // namespace ge

#endif  // GE_OP_MATRIX_CALCULATION_OPS_H

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module ME and the underlying hardware, acting as the bridge between the two. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.
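A minimal sketch of how a client typically drives GE through the GE API (assuming the public ge_api.h Session interface; the option maps are left empty for brevity, the graph body is elided, and the graph id 1 is arbitrary):

    #include <map>
    #include <string>
    #include <vector>
    #include "ge/ge_api.h"  // GEInitialize / GEFinalize / Session

    int main() {
      std::map<std::string, std::string> options;  // global GE options
      ge::GEInitialize(options);                   // bring up GE Core

      ge::Graph graph("demo");  // in practice, the graph handed down by ME,
                                // built from operators such as ge::op::MatMul

      ge::Session session(options);
      session.AddGraph(1, graph);                  // register the graph by id

      std::vector<ge::Tensor> inputs;
      std::vector<ge::Tensor> outputs;
      session.RunGraph(1, inputs, outputs);        // optimize, then execute

      ge::GEFinalize();                            // tear down GE
      return 0;
    }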