
math_ops.h 17 kB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_OP_MATH_OPS_H_
#define GE_OP_MATH_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"

namespace ge {
  21. /**
  22. *@brief Computes the output as (shift + scale * x) ^ power.
  23. *@par Inputs:
  24. * x: A Tensor of type float16 or float32.
  25. *@par Attributes:
  26. *@li power: Optional. Defaults to 1.0.
  27. *@li scale: Optional. Defaults to 1.0.
  28. *@li shift: Optional. Defaults to 0.0.
  29. *@par Outputs:
  30. * y: A Tensor. Has the same type and shape as "x".
  31. */
  32. REG_OP(Power)
  33. .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT}))
  34. .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
  35. .ATTR(power, Float, 1.0)
  36. .ATTR(scale, Float, 1.0)
  37. .ATTR(shift, Float, 0.0)
  38. .OP_END_FACTORY_REG(Power);
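// Illustrative example (not part of the original header): with the attributes
// set to power = 2.0, scale = 3.0, shift = 1.0 and an input element x = 2.0,
// the op computes y = (1.0 + 3.0 * 2.0) ^ 2.0 = 49.0.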
  39. /**
  40. *@brief Compute the lower regularized incomplete Gamma function P(a, x).
  41. *@par Inputs:
  42. *The input a and x must have the same type. Inputs include: \n
  43. *@li a:A Tensor. Must be one of the following types: float, double.
  44. *@li x:A Tensor. Must have the same type as a.
  45. *@par Outputs:
  46. *z:A Tensor. Has the same type as a.
  47. */
  48. REG_OP(Igamma)
  49. .INPUT(a, TensorType({DT_FLOAT, DT_DOUBLE}))
  50. .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
  51. .OUTPUT(z, TensorType({DT_FLOAT, DT_DOUBLE}))
  52. .OP_END_FACTORY_REG(Igamma)
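// For reference, the function computed is P(a, x) = gamma(a, x) / Gamma(a),
// where gamma(a, x) is the lower incomplete Gamma function, i.e. the integral
// of t^(a-1) * e^(-t) over t in [0, x].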
  53. /**
  54. *@brief Compute the upper regularized incomplete Gamma function Q(a, x).
  55. *@par Inputs:
  56. *The input a and x must have the same type. Inputs include: \n
  57. *@li a:A Tensor. Must be one of the following types: float, float64.
  58. *@li x:A Tensor. Must have the same type as a.
  59. *@par Outputs:
  60. *z:A Tensor. Has the same type as a.
  61. */
  62. REG_OP(Igammac)
  63. .INPUT(a, TensorType({DT_FLOAT, DT_DOUBLE}))
  64. .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
  65. .OUTPUT(z, TensorType({DT_FLOAT, DT_DOUBLE}))
  66. .OP_END_FACTORY_REG(Igammac)
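// For reference, Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x), the
// complement of the lower regularized incomplete Gamma function above.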
  67. /**
  68. *@brief Compare values of input to threshold and pack resulting bits into \n
  69. a uint8.
  70. *@par Inputs:
  71. *The input size must be a non-negative int32 scalar Tensor. Inputs include: \n
  72. *@li input:Values to compare against threshold and bitpack.
  73. *@li threshold:Threshold to compare against.
  74. *@par Outputs:
  75. *y:The bitpacked comparisons.
  76. *@attention Constraints: \n
  77. *Currently, the innermost dimension of the tensor must be divisible by 8. \n
  78. */
  79. REG_OP(CompareAndBitpack)
  80. .INPUT(x, TensorType({ DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_INT8, \
  81. DT_INT16, DT_INT32, DT_INT64, DT_BOOL }))
  82. .INPUT(threshold, TensorType({ DT_FLOAT, DT_FLOAT16, DT_DOUBLE, \
  83. DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_BOOL }))
  84. .OUTPUT(y, TensorType(DT_UINT8))
  85. .OP_END_FACTORY_REG(CompareAndBitpack)
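// Illustrative example, assuming TensorFlow-compatible semantics (a bit is set
// when x > threshold, packed most-significant-bit first): with
// x = [1, 2, 3, 4, 5, 6, 7, 8] and threshold = 5, the comparisons are
// [0, 0, 0, 0, 0, 1, 1, 1], which pack into the uint8 value 0b00000111 = 7.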
  86. /**
  87. *@brief Counts the number of occurrences of each value in an integer array. \n
  88. Outputs a vector with length size and the same dtype as weights. If weights \n
  89. are empty, then index i stores the number of times the value i is counted in \n
  90. arr. If weights are non-empty, then index i stores the sum of the value in \n
  91. weights at each index.
  92. *@par Inputs:
  93. *The input size must be a non-negative int32 scalar Tensor. Inputs include: \n
  94. *@li array:int32 Tensor.
  95. *@li size:non-negative int32 scalar Tensor.
  96. *@li weights: is an int32, int64, float32, or double Tensor with the same \n
  97. shape as arr, or a length-0 Tensor, in which case it acts as all weights \n
  98. equal to 1.
  99. *@par Outputs:
  100. *bins:1D Tensor with length equal to size. The counts or summed weights for \n
  101. each value in the range [0, size).
  102. */
  103. REG_OP(Bincount)
  104. .INPUT(array, TensorType(DT_INT32))
  105. .INPUT(size, TensorType(DT_INT32))
  106. .INPUT(weights, TensorType({ DT_FLOAT, DT_INT32, DT_INT64, DT_DOUBLE }))
  107. .OUTPUT(bins, TensorType({ DT_FLOAT, DT_INT32, DT_INT64, DT_DOUBLE }))
  108. .OP_END_FACTORY_REG(Bincount)
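// Illustrative example: with array = [1, 1, 2], size = 4 and empty weights,
// bins = [0, 2, 1, 0]; with weights = [0.5, 0.3, 1.0] instead,
// bins = [0.0, 0.8, 1.0, 0.0].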
  109. /**
  110. *@brief Compute the regularized incomplete beta integral.
  111. *@par Inputs:
  112. *The input b and x must have the same types as a. Inputs include: \n
  113. *@li a:A Tensor. Must be one of the following types: float32, double.
  114. *@li b:A Tensor. Must have the same type as a.
  115. *@li x:A Tensor. Must have the same type as a.
  116. *@par Outputs:
  117. *z:A Tensor. Has the same type as a.
  118. */
  119. REG_OP(Betainc)
  120. .INPUT(a, TensorType({DT_DOUBLE, DT_FLOAT}))
  121. .INPUT(b, TensorType({DT_DOUBLE, DT_FLOAT}))
  122. .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
  123. .OUTPUT(z, TensorType({DT_DOUBLE, DT_FLOAT}))
  124. .OP_END_FACTORY_REG(Betainc)
  125. /**
  126. *@brief Compute the Hurwitz zeta function
  127. *@par Inputs:
  128. *The input q must be the same type as x. Inputs include: \n
  129. *@li x:A Tensor. Must be one of the following types: float32, double.
  130. *@li q:A Tensor. Must have the same type as x.
  131. *@par Outputs:
  132. *z:A Tensor. Has the same type as x.
  133. *@attention Constraints: \n
  134. *The implementation for Zeta on Ascend uses ai cpu, with bad performance. \n
  135. */
  136. REG_OP(Zeta)
  137. .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
  138. .INPUT(q, TensorType({DT_DOUBLE, DT_FLOAT}))
  139. .OUTPUT(z, TensorType({DT_DOUBLE, DT_FLOAT}))
  140. .OP_END_FACTORY_REG(Zeta)
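// For reference, the Hurwitz zeta function is defined by the series
// zeta(x, q) = sum over n >= 0 of 1 / (q + n)^x, convergent for x > 1.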
  141. /**
  142. *@brief Bucketizes 'input' based on 'boundaries'. For example, if the inputs \n
  143. are boundaries = [0, 10, 100] input = [[-5, 10000] [150, 10] [5, 100]] then \n
  144. the output will be output = [[0, 3] [3, 2] [1, 3]]
  145. *@par Inputs:
  146. *The dtype of input x must be int or float. Inputs include: \n
  147. *x:Any shape of Tensor contains with int or float type.
  148. *@par Attributes:
  149. *boundaries:A sorted list of floats gives the boundary of the buckets.
  150. *@par Outputs:
  151. *y:Same shape with 'input', each value of input replaced with bucket index.
  152. */
  153. REG_OP(Bucketize)
  154. .INPUT(x, TensorType({DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT}))
  155. .OUTPUT(y, TensorType({DT_INT32}))
  156. .REQUIRED_ATTR(boundaries, ListFloat)
  157. .OP_END_FACTORY_REG(Bucketize)
  158. /**
  159. *@brief Computes the sum along sparse segments of a tensor.
  160. *@par Inputs:
  161. *The input indices and segment_ids must have same rank. Inputs include: \n
  162. *@li x:A Tensor. Must be one of the following types: float, double, int32, \n
  163. uint8, int16, int8, int64, uint16, uint32, uint64.
  164. *@li indices: A Tensor. Must be one of the following types: int32, int64. \n
  165. A 1-D tensor. Has same rank as segment_ids.
  166. *@li segment_ids: A Tensor of type int32. A 1-D tensor. Values should be \n
  167. sorted and can be repeated.
  168. *@par Outputs:
  169. *y:A Tensor. Has the same type as x.
  170. */
  171. REG_OP(SparseSegmentSum)
  172. .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16,
  173. DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  174. .INPUT(indices, TensorType({DT_INT32}))
  175. .INPUT(segment_ids, TensorType({DT_INT32}))
  176. .OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16,
  177. DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  178. .OP_END_FACTORY_REG(SparseSegmentSum)
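// Illustrative semantics: row i of y is the sum of the rows x[indices[j]]
// over all j with segment_ids[j] == i. For example, with
// x = [[1, 2], [3, 4], [5, 6]], indices = [0, 1] and segment_ids = [0, 0],
// y = [[4, 6]].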
  179. /**
  180. *@brief Computes the mean along sparse segments of a tensor.
  181. *@par Inputs:
  182. *The input indices and segment_ids must have same rank. Inputs include: \n
  183. *@li x: A Tensor. Must be one of the following types: float, double.
  184. *@li indices: A Tensor. Must be one of the following types: int32, int64. \n
  185. A 1-D tensor. Has same rank as segment_ids.
  186. *@li segment_ids: A Tensor of type int32. A 1-D tensor. Values should be \n
  187. sorted and can be repeated.
  188. *@par Outputs:
  189. *y:A Tensor. Has the same type as x.
  190. */
  191. REG_OP(SparseSegmentMean)
  192. .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
  193. .INPUT(indices, TensorType({DT_INT32}))
  194. .INPUT(segment_ids, TensorType({DT_INT32}))
  195. .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE}))
  196. .OP_END_FACTORY_REG(SparseSegmentMean)
  197. /**
  198. *@brief Computes gradients for SparseSegmentMean.
  199. *@par Inputs:
  200. *The input grad must have be type float or double. Inputs include: \n
  201. *@li grad: A Tensor. Must be one of the following types: float, double. \n
  202. gradient propagated to the SparseSegmentMean op.
  203. *@li indices: A Tensor. Must be one of the following types: int32, int64. \n
  204. indices passed to the corresponding SparseSegmentMean op.
  205. *@li segment_ids: A Tensor of type int32. segment_ids passed to the \n
  206. corresponding SparseSegmentMean op.
  207. *@li output_dim0: A Tensor of type int32. dimension 0 of "x" passed to \n
  208. SparseSegmentMean op.
  209. *@par Outputs:
  210. *y:A Tensor. Has the same type as grad.
  211. */
  212. REG_OP(SparseSegmentMeanGrad)
  213. .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
  214. .INPUT(indices, TensorType({DT_INT32}))
  215. .INPUT(segment_ids, TensorType({DT_INT32}))
  216. .INPUT(output_dim0, TensorType({DT_INT32}))
  217. .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE}))
  218. .OP_END_FACTORY_REG(SparseSegmentMeanGrad)
  219. /**
  220. *@brief Computes the gradient of igamma(a, x) wrt a
  221. *@par Inputs:
  222. *The input a and x must have the same type. Inputs include: \n
  223. *@li a:A Tensor. Must be one of the following types: float32, double.
  224. *@li x:A Tensor. Must have the same type as a.
  225. *@par Outputs:
  226. *y:A Tensor. Has the same type as a.
  227. */
  228. REG_OP(IgammaGradA)
  229. .INPUT(a, TensorType({DT_FLOAT, DT_DOUBLE}))
  230. .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE}))
  231. .OUTPUT(z, TensorType({DT_FLOAT, DT_DOUBLE}))
  232. .OP_END_FACTORY_REG(IgammaGradA)
/**
*@brief Initializes the data process channel.
*@par Attributes:
*channel_name: A string. Default "".
*/
REG_OP(InitData)
    .ATTR(channel_name, String, "")
    .OP_END_FACTORY_REG(InitData)
/**
*@brief Gets the next batch of data in data processing.
*@par Attributes:
*@li output_types: A nested structure of DType objects corresponding to each \n
component of an element of this dataset.
*@li output_shapes: A nested structure of TensorShape objects corresponding \n
to each component of an element of this dataset.
*@li output_num: An int. Default 1.
*@li channel_name: A string. Default "".
*@par Outputs:
*y:A nested structure of Tensor objects.
*/
REG_OP(GetNext)
    .DYNAMIC_OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64,
        DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL}))
    .ATTR(output_types, ListInt, {})
    .ATTR(output_shapes, ListListInt, {})
    .ATTR(output_num, Int, 1)
    .ATTR(channel_name, String, "")
    .OP_END_FACTORY_REG(GetNext)
/**
*@brief End of sequence.
*@par Inputs:
*x: A Tensor of type uint8.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(EndOfSequence)
    .INPUT(x, TensorType({DT_UINT8}))
    .OUTPUT(y, TensorType({DT_UINT8}))
    .OP_END_FACTORY_REG(EndOfSequence)
/**
*@brief Computes the Gauss error function of "x" element-wise.
*@par Inputs:
*x: A Tensor of type float16 or float32.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(Erf)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(Erf)
/**
*@brief Computes the Gauss complementary error function of "x" element-wise.
*@par Inputs:
*x: A Tensor of type float16 or float32.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*/
REG_OP(Erfc)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(Erfc)
/**
*@brief Returns a rank-1 histogram counting the number of entries in "x" \n
* that fall into each bin. The bins are equal width and determined by the \n
* inputs 'range' and 'nbins'. \n
*@par Inputs:
*Three inputs, including: \n
*@li x: A Tensor of type float32, float16, or int32.
*@li range: A Tensor of type float32, float16, or int32.
*@li nbins: A Tensor of type int32.
*@par Attributes:
* dtype: An optional attribute. Defaults to "int32".
*@par Outputs:
*y: A Tensor of type int32.
*/
REG_OP(HistogramFixedWidth)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(range, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(nbins, TensorType({DT_INT32}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .ATTR(dtype, String, "int32")
    .OP_END_FACTORY_REG(HistogramFixedWidth)
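// Illustrative example, assuming TensorFlow-compatible semantics: with
// x = [-1.0, 0.0, 1.5, 2.0, 5.0, 15.0], range = [0.0, 5.0] and nbins = 5,
// the bins are [0,1), [1,2), [2,3), [3,4), [4,5]; out-of-range values are
// clamped into the edge bins, giving y = [2, 1, 1, 0, 2].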
/**
*@brief Returns a rank-1 histogram counting the number of entries in "x" \n
* that fall into each bin. The bins are equal width and determined by the \n
* input 'range' and the attribute 'nbins'. \n
*@par Inputs:
*Two inputs, including: \n
*@li x: A Tensor of type float32, float16, or int32.
*@li range: A Tensor of type float32, float16, or int32.
*@par Attributes:
*@li dtype: An optional attribute. Defaults to "int32".
*@li nbins: A required attribute of type int.
*@par Outputs:
*y: A Tensor of type int32.
*/
REG_OP(HistogramFixedWidthD)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(range, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(nbins, Int)
    .ATTR(dtype, String, "int32")
    .OP_END_FACTORY_REG(HistogramFixedWidthD)
/**
*@brief Returns the next representable value of x1 in the direction of x2, element-wise.
*@par Inputs:
*The inputs x1 and x2 must have the same type. Inputs include: \n
*@li x1:A Tensor. Must be one of the following types: float32, double.
*@li x2:A Tensor. Must have the same type as x1.
*@par Outputs:
*output:A Tensor. Has the same type as x1.
*/
REG_OP(NextAfter)
    .INPUT(x1, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(output, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(NextAfter)
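// This mirrors std::nextafter from <cmath>: for float32 inputs,
// NextAfter(1.0f, 2.0f) yields 1.0f + 2^-23, the smallest float greater than 1.0f.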
/**
*@brief Computes element-wise finiteness, returning a boolean tensor.
*@par Inputs:
*x:A Tensor.
*@par Outputs:
*y:A Tensor. Has the same shape as x.
*/
REG_OP(IsFinite)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(IsFinite)
/**
*@brief Computes the complex absolute value of a tensor.
*@par Inputs:
*x:A Tensor.
*@par Outputs:
*y:A tensor of type `float` or `double` that is the absolute value of each element in `x`.
*/
REG_OP(ComplexAbs)
    .INPUT(x, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .ATTR(Tout, Type, DT_FLOAT)
    .OP_END_FACTORY_REG(ComplexAbs)
/**
*@brief Returns which elements of x are NaN.
*@par Inputs:
*x:A Tensor.
*@par Outputs:
*y:A Tensor. Has the same shape as x.
*/
REG_OP(IsNan)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(IsNan)
/**
*@brief Returns the real part of a complex number.
*@par Inputs:
*input:A Tensor.
*@par Outputs:
*output:A Tensor. Has the same shape as input.
*/
REG_OP(Real)
    .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(output, TensorType({DT_FLOAT, DT_DOUBLE}))
    .ATTR(Tout, Type, DT_FLOAT)
    .OP_END_FACTORY_REG(Real)
/**
*@brief Returns the complex conjugate of a complex number.
*@par Inputs:
*input:A Tensor.
*@par Outputs:
*output:A Tensor. Has the same shape as input.
*/
REG_OP(Conj)
    .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(output, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Conj)
/**
*@brief The negative log likelihood loss.
*@par Inputs:
*The inputs x and weight must have the same type. Inputs include: \n
*@li x:A Tensor of type float32.
*@li target:A Tensor of type int32.
*@li weight:A Tensor of type float32.
*@par Attributes:
*reduction: An optional attribute. Defaults to "mean".
*@par Outputs:
*Two outputs, including:
*@li y: A Tensor of type float32.
*@li total_weight: A Tensor of type float32.
*/
REG_OP(NLLLoss)
    .INPUT(x, TensorType({DT_FLOAT}))
    .INPUT(target, TensorType({DT_INT32}))
    .INPUT(weight, TensorType({DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT}))
    .OUTPUT(total_weight, TensorType({DT_FLOAT}))
    .ATTR(reduction, String, "mean")
    .OP_END_FACTORY_REG(NLLLoss)
/**
*@brief The negative log likelihood loss gradient.
*@par Inputs:
*Inputs include:
*@li x:A Tensor of type float32.
*@li y_grad:A Tensor of type float32.
*@li target:A Tensor of type int32.
*@li weight:A Tensor of type float32.
*@li total_weight:A Tensor of type float32.
*@par Attributes:
*reduction: An optional attribute. Defaults to "mean".
*@par Outputs:
*One output, including:
*@li x_grad: A Tensor of type float32.
*/
REG_OP(NLLLossGrad)
    .INPUT(x, TensorType({DT_FLOAT}))
    .INPUT(y_grad, TensorType({DT_FLOAT}))
    .INPUT(target, TensorType({DT_INT32}))
    .INPUT(weight, TensorType({DT_FLOAT}))
    .INPUT(total_weight, TensorType({DT_FLOAT}))
    .OUTPUT(x_grad, TensorType({DT_FLOAT}))
    .ATTR(reduction, String, "mean")
    .OP_END_FACTORY_REG(NLLLossGrad)
}  // namespace ge

#endif  // GE_OP_MATH_OPS_H_
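As a usage sketch (not part of this file), operators registered with REG_OP are typically instantiated through the generated classes in the ge::op namespace when building a graph with the GE IR API. The setter names below (set_input_x, set_attr_power, and so on) follow the usual naming convention generated from REG_OP, but the aggregate header name and exact calls should be treated as assumptions:

#include "graph/graph.h"
#include "all_ops.h"  // assumed aggregate header exposing the ge::op classes

// A minimal sketch: build a one-op graph that applies Power to a Data input.
ge::Graph BuildPowerGraph() {
  auto data = ge::op::Data("x");          // graph input placeholder
  auto power = ge::op::Power("power_0");  // generated from REG_OP(Power)
  power.set_input_x(data);                // wire the input tensor
  power.set_attr_power(2.0f);             // y = (shift + scale * x) ^ power
  power.set_attr_scale(3.0f);
  power.set_attr_shift(1.0f);

  ge::Graph graph("power_graph");
  graph.SetInputs({data}).SetOutputs({power});
  return graph;
}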

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware and acts as a bridge: it takes the graph delivered by ME as input, performs a series of deep graph optimizations, and outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture diagram is shown below.