
matrix_calculation_ops.h 23 kB

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GE_OP_MATRIX_CALCULATION_OPS_H
#define GE_OP_MATRIX_CALCULATION_OPS_H

#include "../graph/operator_reg.h"

namespace ge {
/**
*@brief Multiplies matrix "x1" by matrix "x2", producing "x1 * x2".
*@par Inputs:
*Three inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M].
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
*/
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)
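/*
 * Usage sketch (illustrative addition, not part of the original header):
 * REG_OP conventionally generates set_input_* / set_attr_* accessors, so a
 * MatMul node could be wired up roughly as follows, where "x1_op" and
 * "x2_op" are hypothetical upstream operators:
 *
 *   auto matmul = op::MatMul("matmul")
 *       .set_input_x1(x1_op)
 *       .set_input_x2(x2_op)
 *       .set_attr_transpose_x1(false)
 *       .set_attr_transpose_x2(false);
 */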
/**
*@brief Performs matrix-to-matrix multiply, producing c = alpha[0] * a * b + beta[0] * c.
*@par Inputs:
*Five inputs, including:
*@li a: A matrix Tensor. 4D. Must be one of the following types:\n float16, int8. Has format [FRACTAL_NZ].
*@li b: A matrix Tensor. 4D. Must be one of the following types:\n float16, int8. When type is int8, has format [FRACTAL_Z],\n otherwise has format [FRACTAL_NZ].
*@li c: A matrix Tensor. 2D or higher. Must be one of the following types:\n float16, int32, float32. When type is int32, has format [ND],\n otherwise has format [FRACTAL_NZ].
*@li alpha: A 1D Tensor of shape [1].\n Must be one of the following types: float16, int32, float32. Has format [ND].
*@li beta: A 1D Tensor of shape [1].\n Must be one of the following types: float16, int32, float32. Has format [ND].
*@par Attributes:
*Two attributes, including:
*@li transpose_a: Optional. A bool.\n If True, changes the shape of "a" from [M, K] to [K, M].\n Reserved parameter, not used for now.
*@li transpose_b: Optional. A bool.\n If True, changes the shape of "b" from [M, K] to [K, M].\n Reserved parameter, not used for now.
*@par Outputs:
*out: The result matrix Tensor. 4D. Must be one of the following types:\n float16, float32, int32. Has format [FRACTAL_NZ].
*/
REG_OP(Gemm)
    .INPUT(a, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(b, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(out, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(Gemm)
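/*
 * Semantics sketch (illustrative addition): with scalar tensors alpha and
 * beta, Gemm computes
 *   out = alpha[0] * (a x b) + beta[0] * c
 * so alpha = [1.0], beta = [0.0] reduces it to a plain matrix multiply.
 */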
/**
*@brief Multiplies matrix "x1" by matrix "x2" in batches, producing "x1 * x2".
*@par Inputs:
*Two inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
*@par Attributes:
*@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
*@li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M].
*@par Outputs:
*y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ]. Has the same shape length as "x1" and "x2".
*/
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)
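/*
 * Shape sketch (illustrative addition): with adj_x1 = adj_x2 = false,
 * x1 of shape [B, M, K] and x2 of shape [B, K, N] produce y of shape
 * [B, M, N]; each of the B slices is an independent matrix multiply.
 */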
REG_OP(MeanCCE)
    .INPUT(x, TensorType::ALL())
    .INPUT(indices, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(keep_dims, Bool, false)
    .ATTR(value1, ListInt, {})
    .ATTR(mode, Int, 3)                 // 0: max pooling; 1: avg pooling
    .ATTR(pad_mode, Int, 0)
    .ATTR(global_pooling, Bool, true)
    .ATTR(window, ListInt, {1, 1})      // kernel size
    .ATTR(pad, ListInt, {0, 0, 0, 0})   // pad size
    .ATTR(stride, ListInt, {1, 1})      // stride size
    .ATTR(ceil_mode, Int, 0)
    .ATTR(data_mode, Int, 1)
    .ATTR(nan_opt, Int, 0)
    .ATTR(fomart, Int, 0)               // (sic) attribute name kept as registered
    .OP_END_FACTORY_REG(MeanCCE)
REG_OP(MeanGrad)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(mode, Int, 1)                 // 0: max pooling; 1: avg pooling
    .ATTR(pad_mode, Int, 0)
    .ATTR(global_pooling, Bool, false)
    .ATTR(window, ListInt, {1, 1})      // kernel size
    .ATTR(pad, ListInt, {0, 0, 0, 0})   // pad size
    .ATTR(stride, ListInt, {1, 1})      // stride size
    .ATTR(ceil_mode, Int, 0)
    .ATTR(data_mode, Int, 1)
    .ATTR(nan_opt, Int, 0)
    .ATTR(mean_grad_output_shape_value, ListInt, {1, 1, 1, 1})
    .ATTR(mean_grad_output_shape_format, Int, 1)  // must be NHWC
    .OP_END_FACTORY_REG(MeanGrad)

REG_OP(MatMulCCE)
    .INPUT(x1, TensorType({DT_FLOAT}))
    .INPUT(x2, TensorType({DT_FLOAT}))
    .OPTIONAL_INPUT(x3, TensorType({DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .ATTR(has_bias, Bool, false)
    .OP_END_FACTORY_REG(MatMulCCE)
/**
*@brief Computes half the L2 norm of a tensor without the sqrt.
*@par Inputs:
*x: A Tensor of type TensorType::FloatingDataType().
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)
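/*
 * Semantics sketch (illustrative addition): y = sum over all elements of
 * x[i]^2, divided by 2. For example, x = [3.0, 4.0] gives
 * y = (9 + 16) / 2 = 12.5.
 */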
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)
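/*
 * Semantics sketch (illustrative addition): for each innermost vector of
 * "x", MatrixDiag emits a square matrix with that vector on the diagonal,
 * e.g. x = [1, 2, 3] yields
 *   y = [[1, 0, 0],
 *        [0, 2, 0],
 *        [0, 0, 3]]
 * MatrixDiagPart below performs the inverse extraction.
 */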
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)

/**
*@brief Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)
/**
*@brief Returns the batched diagonal part of a batched tensor.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)

/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)

/**
*@brief Returns a batched matrix tensor with new batched diagonal values.
*@par Inputs:
* Three inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)

/**
*@brief Applies sparse "updates" to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.\n
*Must be one of the following types: float16, float32, int8, uint8, bool
*@li indices: An ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An ND Tensor.\n
*Must be one of the following types: float16, float32, int8, uint8, bool
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
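/*
 * Semantics sketch (illustrative addition): each row of "indices" addresses
 * an element or slice of "var" to overwrite. For a 1D "var",
 * indices = [[1], [3]] and updates = [9, 10] replace var[1] with 9 and
 * var[3] with 10; all other elements are unchanged.
 */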
/**
*@brief Adds sparse "updates" to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32.
*@li updates: An ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)
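/*
 * Semantics sketch (illustrative addition): roughly
 *   var[indices[i]] += updates[i] for each i,
 * e.g. var = [1, 1, 1], indices = [0, 0, 2], updates = [5, 5, 5]
 * leaves var = [11, 1, 6] (duplicate indices accumulate).
 */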
/**
*@brief Divides a variable reference by sparse updates.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li isRef: An optional bool. Defaults to "True".
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)
/**
*@brief Applies sparse addition to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)

/**
*@brief Applies sparse subtraction to individual values or slices in a Variable.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)

/**
*@brief Subtracts sparse updates from a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)
/**
*@brief Returns the batched diagonal part of a batched tensor with "assist".
*@par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)

/**
*@brief Returns the batched diagonal part of a batched tensor.
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32, int32, int64, double, complex64, complex128.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)
/**
*@brief Also known as a "fully-connected" layer; computes an inner product with a set of learned weights and (optionally) adds biases.
*@par Inputs:
* Four inputs, including:
*@li x: A Tensor of type float16, int8.
*@li w: A weight matrix of type float16, int8.
*@li b: An optional Tensor of type float16, int32.
*@li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li num_output: A required Int. Reserved.
*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false".
*@li bias_term: A bool, specifying whether to learn and apply a set of additive biases to the filter outputs, either "true" or "false". Defaults to "true".
*@li axis: An Int. Only the value 1 is supported. Defaults to "1".
*@li offset_a: An Int. Defaults to "0".
*@par Outputs:
*y: The result tensor of type float16, int32.
*/
REG_OP(InnerProduct)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(bias_term, Bool, true)
    .ATTR(axis, Int, 1)
    .ATTR(offset_a, Int, 0)
    .OP_END_FACTORY_REG(InnerProduct)
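/*
 * Shape sketch (illustrative addition): with x of shape [batch, K],
 * num_output = N and bias_term = true, the layer computes roughly
 *   y = x * w + b
 * (the exact weight layout is backend-defined), giving y of shape
 * [batch, N].
 */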
/**
*@brief Computes the confusion matrix from predictions and labels.
*@par Inputs:
* Three inputs, including:
*@li labels: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li predictions: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li weights: An optional Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@par Attributes:
*@li num_classes: An integer for the shape of the output matrix. No default value.
*@li dtype: Data type of the confusion matrix. No default value.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "labels".
*@attention Constraints:
*@li "weights", "labels", and "predictions" are 1D tensors.
*@li The output has shape (num_classes, num_classes), where 1 <= num_classes <= 4096.
*@see Region()
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)
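/*
 * Semantics sketch (illustrative addition): y[label][prediction] is
 * incremented for every sample (by the matching weight when "weights" is
 * supplied). With num_classes = 3, labels = [0, 2, 2] and
 * predictions = [0, 1, 2]:
 *   y = [[1, 0, 0],
 *        [0, 0, 0],
 *        [0, 1, 1]]
 */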
/**
*@brief Multiplies sparse updates into a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)

/**
*@brief Reduces sparse updates into a variable reference using the "min" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int32
*@li indices: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int32
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)

/**
*@brief Reduces sparse updates into a variable reference using the "max" operation.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int32
*@li indices: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int32
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)
/**
*@brief Applies sparse updates to a variable reference.
*@par Inputs:
* Three inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: int32
*@li updates: An NCHW, NHWC, or ND Tensor.\n
*Must be one of the following types: float16, float32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)
}  // namespace ge

#endif  // GE_OP_MATRIX_CALCULATION_OPS_H
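As a usage sketch (illustrative, not from this header): the REG_OP macro generates one operator class per registration, and graph-construction code wires these classes together. The accessor names below (set_input_x1, set_attr_transpose_x1) and the Data operator follow the usual GE IR naming convention; treat them as assumptions rather than a verified API for this exact version.

    #include "graph/graph.h"
    #include "matrix_calculation_ops.h"

    // Build a two-input graph around the MatMul op registered above.
    ge::Graph BuildMatMulGraph() {
      auto x1 = ge::op::Data("x1").set_attr_index(0);  // first graph input
      auto x2 = ge::op::Data("x2").set_attr_index(1);  // second graph input

      // y = x1 * x2, no transposition, bias left unset (it is optional).
      auto matmul = ge::op::MatMul("matmul")
                        .set_input_x1(x1)
                        .set_input_x2(x2)
                        .set_attr_transpose_x1(false)
                        .set_attr_transpose_x2(false);

      ge::Graph graph("matmul_graph");
      graph.SetInputs({x1, x2}).SetOutputs({matmul});
      return graph;
    }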

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware, acting as the bridge between them. GE takes the graph handed down by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE mainly consists of two parts, GE API and GE Core, as shown in the detailed architecture diagram below.
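For orientation, a hedged sketch of how a client drives GE through GE API (a Session fronts GE Core; the option maps are left empty for brevity, and exact headers and signatures may differ between versions):

    #include <map>
    #include <vector>
    #include "ge/ge_api.h"

    // Hand a ME-built graph to GE, let GE Core optimize it, then run it once.
    ge::Status RunGraphOnce(const ge::Graph &graph) {
      std::map<ge::AscendString, ge::AscendString> options;
      ge::GEInitialize(options);          // bring up GE
      ge::Status ret;
      {
        ge::Session session(options);     // GE API entry point
        session.AddGraph(1, graph);       // register the ME-built graph
        std::vector<ge::Tensor> inputs;   // graph inputs (empty for brevity)
        std::vector<ge::Tensor> outputs;
        ret = session.RunGraph(1, inputs, outputs);  // optimize + execute
      }                                   // session destroyed before finalize
      ge::GEFinalize();
      return ret;
    }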