You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn_training_ops.h 54 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_OP_TRAINING_OPS_H
  17. #define GE_OP_TRAINING_OPS_H
  18. #include "../graph/operator_reg.h"
  19. namespace ge {
  20. /**
  21. *@brief Updates "var" according to the AdaMax algorithm.\n
  22. * t-1 mean previous period.
  23. * m_t <- beta1 * m{t-1} + (1 - beta1) * grad\n
  24. * v_t <- max(beta2 * v{t-1}, abs(grad))\n
  25. * var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
  26. *
  27. *@attention Constraints:\n
  28. * the input tensors must have the same shape.
  29. *
  30. *@par Inputs:
  31. *@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
  32. * Should be from a Variable().
  33. *@li m: A mutable tensor. Has the same type as "var".
  34. * Should be from a Variable().
  35. *@li v: A mutable tensor. Has the same type as "var".
  36. * Should be from a Variable().
  37. *@li beta1_power: A scalar. Has the same type as "var".
  38. *@li lr: learning_rate. A scalar. Has the same type as "var".
  39. *@li beta1: A scalar. Has the same type as "var".
  40. *@li beta2: A scalar. Has the same type as "var".
  41. *@li epsilon: A scalar. Has the same type as "var".
  42. *@li grad: A tensor for the gradient. Has the same type as "var".
  43. *
  44. *@par Attributes:\n
  45. * use_locking: An optional bool. Defaults to "False".
  46. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  47. * by a lock; otherwise the behavior is undefined, but may exhibit less
  48. * contention.
  49. *
  50. *@par Outputs:
  51. * var: A mutable tensor. Has the same type as input "var".
  52. *
  53. */
  54. REG_OP(ApplyAdaMax)
  55. .INPUT(var, TensorType::NumberType())
  56. .INPUT(m, TensorType::NumberType())
  57. .INPUT(v, TensorType::NumberType())
  58. .INPUT(beta1_power, TensorType::NumberType())
  59. .INPUT(lr, TensorType::NumberType())
  60. .INPUT(beta1, TensorType::NumberType())
  61. .INPUT(beta2, TensorType::NumberType())
  62. .INPUT(epsilon, TensorType::NumberType())
  63. .INPUT(grad, TensorType::NumberType())
  64. .OUTPUT(var, TensorType::NumberType())
  65. .ATTR(use_locking, Bool, false)
  66. .OP_END_FACTORY_REG(ApplyAdaMax)
  67. /**
  68. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  69. *@par Inputs:
  70. * Five inputs, including:
  71. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  72. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  73. *@li lr: An NCHW, NHWC, or ND Tensor of type float32.
  74. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  75. *@li indices: An NCHW, NHWC, or ND Tensor of type float32.
  76. *@par Attributes:
  77. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  78. *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False".
  79. *@par Outputs:
  80. *var: A Tensor. Has the same type and format as input "var".
  81. */
  82. REG_OP(SparseApplyAdagrad)
  83. .INPUT(var, TensorType({DT_FLOAT}))
  84. .INPUT(accum, TensorType({DT_FLOAT}))
  85. .INPUT(lr, TensorType({DT_FLOAT}))
  86. .INPUT(grad, TensorType({DT_FLOAT}))
  87. .INPUT(indices, TensorType({DT_INT32}))
  88. .OUTPUT(var, TensorType({DT_FLOAT}))
  89. .ATTR(use_locking, Bool, false)
  90. .ATTR(update_slots, Bool, true)
  91. .OP_END_FACTORY_REG(SparseApplyAdagrad)
  92. /**
  93. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  94. *@par Inputs:
  95. * Four inputs, including:
  96. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  97. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  98. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  99. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  100. *@par Attributes:
  101. *@li lr: Required, used for computation.
  102. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  103. *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False".
  104. *@par Outputs:
  105. *var: A Tensor. Has the same type and format as input "var".
  106. */
  107. REG_OP(SparseApplyAdagradD)
  108. .INPUT(var, TensorType({DT_FLOAT}))
  109. .INPUT(accum, TensorType({DT_FLOAT}))
  110. .INPUT(grad, TensorType({DT_FLOAT}))
  111. .INPUT(indices, TensorType({DT_INT32}))
  112. .OUTPUT(var, TensorType({DT_FLOAT}))
  113. .REQUIRED_ATTR(lr, Float)
  114. .ATTR(use_locking, Bool, false)
  115. .ATTR(update_slots, Bool, true)
  116. .OP_END_FACTORY_REG(SparseApplyAdagradD)
  117. /**
  118. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  119. *@par Inputs:
  120. * Five inputs, including:
  121. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  122. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  123. *@li lr: An NCHW, NHWC, or ND Tensor of type float32.
  124. *@li epsilon: An NCHW, NHWC, or ND Tensor of type float32.
  125. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  126. *@li indices: An NCHW, NHWC, or ND Tensor of type float32.
  127. *@par Attributes:
  128. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  129. *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False".
  130. *@par Outputs:
  131. *var: A Tensor. Has the same type and format as input "var".
  132. */
  133. REG_OP(SparseApplyAdagradV2)
  134. .INPUT(var, TensorType({DT_FLOAT}))
  135. .INPUT(accum, TensorType({DT_FLOAT}))
  136. .INPUT(lr, TensorType({DT_FLOAT}))
  137. .INPUT(epsilon, TensorType({DT_FLOAT}))
  138. .INPUT(grad, TensorType({DT_FLOAT}))
  139. .INPUT(indices, TensorType({DT_INT32}))
  140. .OUTPUT(var, TensorType({DT_FLOAT}))
  141. .ATTR(use_locking, Bool, false)
  142. .ATTR(update_slots, Bool, true)
  143. .OP_END_FACTORY_REG(SparseApplyAdagradV2)
  144. /**
  145. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  146. *@par Inputs:
  147. * Four inputs, including:
  148. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  149. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  150. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  151. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  152. *@par Attributes:
  153. *@li lr: Required, used for computation.
  154. *@li epsilon: Required, used for computation.
  155. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  156. *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False".
  157. *@par Outputs:
  158. *var: A Tensor. Has the same type and format as input "var".
  159. *accum: A Tensor. Has the same type and format as input "accum".
  160. */
  161. REG_OP(SparseApplyAdagradV2D)
  162. .INPUT(var, TensorType({DT_FLOAT}))
  163. .INPUT(accum, TensorType({DT_FLOAT}))
  164. .INPUT(grad, TensorType({DT_FLOAT}))
  165. .INPUT(indices, TensorType({DT_INT32}))
  166. .OUTPUT(var, TensorType({DT_FLOAT}))
  167. .OUTPUT(accum, TensorType({DT_FLOAT}))
  168. .REQUIRED_ATTR(lr, Float)
  169. .REQUIRED_ATTR(epsilon, Float)
  170. .ATTR(use_locking, Bool, false)
  171. .ATTR(update_slots, Bool, true)
  172. .OP_END_FACTORY_REG(SparseApplyAdagradV2D)
  173. /**
  174. *@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
  175. * want to use Nesterov momentum.\n
  176. * computing process: \n
  177. * accum = accum * momentum + grad\n
  178. * var -= lr * accum
  179. *
  180. *@attention Constraints:\n
  181. * the input tensors must have the same shape.
  182. *
  183. *@par Inputs:
  184. *@li var: A mutable tensor. Should be from a Variable().
  185. *@li accum: A mutable tensor. Has the same type as "var".
  186. * Should be from a Variable().
  187. *@li lr: A scalar. Has the same type as "var".
  188. *@li grad: A tensor for the gradient. Has the same type as "var".
  189. *
  190. *@par Attributes:
  191. *@li use_nesterov: An optional bool. Defaults to "False".
  192. * If "True", the tensor passed to compute grad will be
  193. * var - lr * momentum * accum, so in the end, the var you get is actually
  194. * var - lr * momentum * accum.
  195. *
  196. *@li use_locking: An optional bool. Defaults to "False".\n
  197. * If "True", updating of the "var", "ms", and "mom" tensors is protected by a lock;
  198. * otherwise the behavior is undefined, but may exhibit less contention.
  199. *
  200. *@par Outputs:
  201. * var: A mutable tensor. Has the same type as input "var".
  202. *
  203. */
  204. REG_OP(ApplyMomentum)
  205. .INPUT(var, TensorType::NumberType())
  206. .INPUT(accum, TensorType::NumberType())
  207. .INPUT(lr, TensorType::NumberType())
  208. .INPUT(grad, TensorType::NumberType())
  209. .INPUT(momentum, TensorType::NumberType())
  210. .OUTPUT(var, TensorType::NumberType())
  211. .ATTR(use_nesterov, Bool, false)
  212. .ATTR(use_locking, Bool, false)
  213. .OP_END_FACTORY_REG(ApplyMomentum)
  214. REG_OP(ApplyMomentumCCE)
  215. .INPUT(var, TensorType::NumberType())
  216. .INPUT(accum, TensorType::NumberType())
  217. .INPUT(lr, TensorType::NumberType())
  218. .INPUT(grad, TensorType::NumberType())
  219. .INPUT(momentum, TensorType::NumberType())
  220. .OUTPUT(var, TensorType::NumberType())
  221. .ATTR(use_nesterov, Bool, false)
  222. .ATTR(use_locking, Bool, false)
  223. .OP_END_FACTORY_REG(ApplyMomentumCCE)
  224. /**
  225. *@brief Updates "var" according to the AddSign update.\n
  226. * t-1 mean previous period.
  227. * m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
  228. * update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
  229. * var <- var - lr * update
  230. *
  231. *@attention Constraints:\n
  232. * the input tensors must have the same shape.
  233. *
  234. *@par Inputs:
  235. *@li var: A mutable tensor. Should be from a Variable().
  236. *@li m: A mutable tensor. Has the same type as "var".
  237. * Should be from a Variable().
  238. *@li lr: A scalar. Has the same type as "var".
  239. *@li logbase: A scalar. Has the same type as "var".
  240. *@li sign_decay: A scalar. Has the same type as "var".
  241. *@li beta: A scalar. Has the same type as "var".
  242. *@li grad: A tensor for the gradient. Has the same type as "var".
  243. *
  244. *@par Attributes:
  245. * use_locking: An optional bool. Defaults to "False".
  246. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  247. * by a lock; otherwise the behavior is undefined, but may exhibit less
  248. * contention.
  249. *
  250. *@par Outputs:
  251. * var: A mutable tensor. Has the same type as input "var".
  252. *
  253. */
  254. REG_OP(ApplyPowerSign)
  255. .INPUT(var, TensorType::NumberType())
  256. .INPUT(m, TensorType::NumberType())
  257. .INPUT(lr, TensorType::NumberType())
  258. .INPUT(logbase, TensorType::NumberType())
  259. .INPUT(sign_decay, TensorType::NumberType())
  260. .INPUT(beta, TensorType::NumberType())
  261. .INPUT(grad, TensorType::NumberType())
  262. .OUTPUT(var, TensorType::NumberType())
  263. .ATTR(use_locking, Bool, false)
  264. .OP_END_FACTORY_REG(ApplyPowerSign)
  265. /**
  266. *@brief Updates "var" as FOBOS algorithm with fixed learning rate.\n
  267. * prox_v = var - alpha * delta\n
  268. * var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
  269. *
  270. *@attention Constraints:\n
  271. * the input tensors must have the same shape.
  272. *
  273. *@par Inputs:
  274. *@li var: A mutable tensor. Should be from a Variable().
  275. *@li alpha: A scalar. Has the same type as "var".
  276. *@li l1: A scalar. Has the same type as "var".
  277. *@li l2: A scalar. Has the same type as "var".
  278. *@li delta: A tensor. Has the same type as "var". The change.
  279. *
  280. *@par Attributes:
  281. * use_locking: An optional bool. Defaults to "False".
  282. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  283. * by a lock; otherwise the behavior is undefined, but may exhibit less
  284. * contention.
  285. *
  286. *@par Outputs:
  287. * var: A mutable tensor. Has the same type as input "var".
  288. *
  289. */
  290. REG_OP(ApplyProximalGradientDescent)
  291. .INPUT(var, TensorType::NumberType())
  292. .INPUT(alpha, TensorType::NumberType())
  293. .INPUT(l1, TensorType::NumberType())
  294. .INPUT(l2, TensorType::NumberType())
  295. .INPUT(delta, TensorType::NumberType())
  296. .OUTPUT(var, TensorType::NumberType())
  297. .ATTR(use_locking, Bool, false)
  298. .OP_END_FACTORY_REG(ApplyProximalGradientDescent)
  299. /**
  300. *@brief Updates "var" according to the AddSign update.
  301. *@par Inputs:
  302. *Seven inputs, including:
  303. * @li var: A mutable Tensor of type TensorType::NumberType().
  304. * Should be a Variable Tensor.
  305. * @li m: A mutable Tensor of the same type as "var".
  306. * Should be a Variable Tensor.
  307. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  308. * @li alpha: A Tensor of the same type as "var". Must be a scalar.
  309. * @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
  310. * @li beta: A Tensor of the same type as "var". Must be a scalar.
  311. * @li grad: A Tensor of the same type as "var", for the gradient.
  312. *@par Attributes:
  313. *use_locking: An optional bool. Defaults to "False".
  314. * If "True", updating of the "var" and "m" tensors will be
  315. * protected by a lock; otherwise the behavior is undefined,
  316. * but may exhibit less contention.
  317. *@par Outputs:
  318. *var: A mutable Tensor. Has the same type as "var".
  319. */
  320. REG_OP(ApplyAddSign)
  321. .INPUT(var, TensorType::NumberType())
  322. .INPUT(m, TensorType::NumberType())
  323. .INPUT(lr, TensorType::NumberType())
  324. .INPUT(alpha, TensorType::NumberType())
  325. .INPUT(sign_decay, TensorType::NumberType())
  326. .INPUT(beta, TensorType::NumberType())
  327. .INPUT(grad, TensorType::NumberType())
  328. .OUTPUT(var, TensorType::NumberType())
  329. .ATTR(use_locking, Bool, false)
  330. .OP_END_FACTORY_REG(ApplyAddSign)
  331. /**
  332. *@brief Updates "var" according to the centered RMSProp algorithm.\n
  333. * The centered RMSProp algorithm uses an estimate of the centered second moment
  334. * (i.e., the variance) for normalization, as opposed to regular RMSProp, which
  335. * uses the (uncentered) second moment. This often helps with training, but is
  336. * slightly more expensive in terms of computation and memory.
  337. *
  338. * t-1 mean previous period.
  339. * mg <- rho * mg{t-1} + (1-rho) * grad\n
  340. * ms <- rho * ms{t-1} + (1-rho) * grad * grad\n
  341. * mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
  342. * var <- var - mom\n
  343. *
  344. *@attention Constraints:\n
  345. *@li in dense implementation of this algorithm, mg, ms, and mom will
  346. * update even if the grad is zero, but in this sparse implementation, mg, ms,
  347. * and mom will not update in iterations during which the grad is zero.
  348. *@li the input tensors must have the same shape.
  349. *
  350. *@par Inputs:
  351. *@li var: A mutable tensor. Should be from a Variable().
  352. *@li mg: A mutable tensor. Has the same type as "var".
  353. * Should be from a Variable().
  354. *@li ms: A mutable tensor. Has the same type as "var".
  355. * Should be from a Variable().
  356. *@li mom: A mutable tensor. Has the same type as "var".
  357. * Should be from a Variable().
  358. *@li lr: A scalar. Has the same type as "var".
  359. *@li rho: A scalar. Has the same type as "var".
  360. *@li momentum: A tensor. Has the same type as "var".
  361. *@li epsilon: A scalar. Has the same type as "var".
  362. *@li grad: A tensor for the gradient. Has the same type as "var".
  363. *
  364. *@par Attributes:
  365. * use_locking: An optional bool. Defaults to "False".
  366. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  367. * by a lock; otherwise the behavior is undefined, but may exhibit less
  368. * contention.
  369. *
  370. *@par Outputs:
  371. * var: A mutable tensor. Has the same type as input "var".
  372. *
  373. */
  374. REG_OP(ApplyCenteredRMSProp)
  375. .INPUT(var, TensorType::NumberType())
  376. .INPUT(mg, TensorType::NumberType())
  377. .INPUT(ms, TensorType::NumberType())
  378. .INPUT(mom, TensorType::NumberType())
  379. .INPUT(lr, TensorType::NumberType())
  380. .INPUT(rho, TensorType::NumberType())
  381. .INPUT(momentum, TensorType::NumberType())
  382. .INPUT(epsilon, TensorType::NumberType())
  383. .INPUT(grad, TensorType::NumberType())
  384. .OUTPUT(var, TensorType::NumberType())
  385. .ATTR(use_locking, Bool, false)
  386. .OP_END_FACTORY_REG(ApplyCenteredRMSProp)
  387. /**
  388. *@brief Updates "var" by subtracting 'alpha' * 'delta' from it.\n
  389. * var -= delta * alpha
  390. *
  391. *@attention Constraints:\n
  392. * the input tensors must have the same shape.
  393. *
  394. *@par Inputs:
  395. *@li var: A mutable tensor. Should be from a Variable().
  396. *@li alpha: A scalar. Has the same type as "var".
  397. *@li delta: A tensor for the change. Has the same type as "var".
  398. *
  399. *@par Attributes:
  400. * use_locking: An optional bool. Defaults to "False".
  401. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  402. * by a lock; otherwise the behavior is undefined, but may exhibit less
  403. * contention.
  404. *
  405. *@par Outputs:
  406. * var: A mutable tensor. Has the same type as input "var".
  407. *
  408. */
  409. REG_OP(ApplyGradientDescent)
  410. .INPUT(var, TensorType::NumberType())
  411. .INPUT(alpha, TensorType::NumberType())
  412. .INPUT(delta, TensorType::NumberType())
  413. .OUTPUT(var, TensorType::NumberType())
  414. .ATTR(use_locking, Bool, false)
  415. .OP_END_FACTORY_REG(ApplyGradientDescent)
  416. /**
  417. *@brief Updates "var" according to the adagrad scheme.\n
  418. * accum += grad * grad\n
  419. * var -= lr * grad * (1 / sqrt(accum))
  420. *
  421. *@attention Constraints:\n
  422. * the input tensors must have the same shape.
  423. *
  424. *@par Inputs:
  425. *@li var: A mutable tensor. Should be from a Variable().
  426. *@li accum: A mutable tensor. Has the same type as "var".
  427. * Should be from a Variable().
  428. *@li lr: A scalar. Has the same type as "var".
  429. *@li grad: A tensor for the gradient. Has the same type as "var".
  430. *
  431. *@par Attributes:
  432. * use_locking: An optional bool. Defaults to "False".
  433. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  434. * by a lock; otherwise the behavior is undefined, but may exhibit less
  435. * contention.
  436. *
  437. *@par Outputs:
  438. * var: A mutable tensor. Has the same type as input "var".
  439. *
  440. */
  441. REG_OP(ApplyAdagrad)
  442. .INPUT(var, TensorType::NumberType())
  443. .INPUT(accum, TensorType::NumberType())
  444. .INPUT(lr, TensorType::NumberType())
  445. .INPUT(grad, TensorType::NumberType())
  446. .OUTPUT(var, TensorType::NumberType())
  447. .ATTR(update_slots, Bool, true)
  448. .ATTR(use_locking, Bool, false)
  449. .OP_END_FACTORY_REG(ApplyAdagrad)
  450. /**
  451. * @brief Updates "var" according to the adagradv2 scheme.\n
  452. * accum += grad * grad \n
  453. * var -= lr * grad * (1 / sqrt(accum) + epsilon)
  454. *
  455. * @attention Constraints:
  456. * the input tensors must have the same shape.
  457. *
  458. * @par Inputs:
  459. * @li var: A mutable tensor. Must be one of the data types defined in
  460. * TensorType::NumberType(). Should be from a Variable().
  461. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  462. * Variable().
  463. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  464. * from a Variable().
  465. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  466. * from a Variable().
  467. * @li epsilon: A scalar. Has the same type as "var".
  468. *
  469. * @par Attributes:
  470. * @li update_slots: An optional bool. Defaults to "True".
  471. * If "True", accum will be updated
  472. * @li use_locking: An optional bool. Defaults to "False".
  473. * If "True", updating of the "var" tensor is protected by a lock;
  474. * otherwise the behavior is undefined, but may exhibit less contention.
  475. *
  476. * @par Outputs:
  477. * var: A mutable tensor. Has the same type as input "var".
  478. *
  479. */
  480. REG_OP(ApplyAdagradV2)
  481. .INPUT(var, TensorType::NumberType())
  482. .INPUT(accum, TensorType::NumberType())
  483. .INPUT(lr, TensorType::NumberType())
  484. .INPUT(epsilon, TensorType::NumberType())
  485. .INPUT(grad, TensorType::NumberType())
  486. .OUTPUT(var, TensorType::NumberType())
  487. .ATTR(update_slots, Bool, true)
  488. .ATTR(use_locking, Bool, false)
  489. .OP_END_FACTORY_REG(ApplyAdagradV2)
  490. /**
  491. * @brief Updates "var" according to the adagradv2 scheme.\n
  492. * accum += grad * grad \n
  493. * var -= lr * grad * (1 / sqrt(accum) + epsilon)
  494. *
  495. * @attention Constraints:
  496. * the input tensors must have the same shape.
  497. *
  498. * @par Inputs:
  499. * @li var: A mutable tensor. Must be one of the data types defined in
  500. * TensorType::NumberType(). Should be from a Variable().
  501. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  502. * Variable().
  503. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  504. * from a Variable().
  505. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  506. * from a Variable().
  507. *
  508. * @par Attributes:
  509. * @li epsilon: A scalar. Has the same type as "var".
  510. * @li update_slots: An optional bool. Defaults to "True".
  511. * If "True", accum will be updated
  512. * @li use_locking: An optional bool. Defaults to "False".
  513. * If "True", updating of the "var" tensor is protected by a lock;
  514. * otherwise the behavior is undefined, but may exhibit less contention.
  515. *
  516. * @par Outputs:
  517. * var: A mutable tensor. Has the same type as input "var".
  518. *
  519. */
  520. REG_OP(ApplyAdagradV2D)
  521. .INPUT(var, TensorType::NumberType())
  522. .INPUT(accum, TensorType::NumberType())
  523. .INPUT(lr, TensorType::NumberType())
  524. .INPUT(grad, TensorType::NumberType())
  525. .OUTPUT(var, TensorType::NumberType())
  526. .OUTPUT(accum, TensorType::NumberType())
  527. .REQUIRED_ATTR(epsilon, Float)
  528. .ATTR(update_slots, Bool, true)
  529. .ATTR(use_locking, Bool, false)
  530. .OP_END_FACTORY_REG(ApplyAdagradV2D)
  531. /**
  532. *@brief Updates "var" according to the proximal adagrad scheme.
  533. *@par Inputs:
  534. *Eight inputs, including:
  535. * @li var: A mutable Tensor. Must be one of the following types:
  536. * TensorType::NumberType(). Should be a Variable Tensor.
  537. * @li gradient_accumulator: A mutable Tensor. Must have the same
  538. * type as "var". Should be a Variable Tensor.
  539. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  540. * Should be a Variable Tensor.
  541. * @li grad: A Tensor of the same type as "var", for the gradient.
  542. * @li lr: A Tensor of the same type as "var".
  543. * Scaling factor. Must be a scalar.
  544. * @li l1: A Tensor of the same type as "var".
  545. * L1 regulariation. Must be a scalar.
  546. * @li l2: A Tensor of the same type as "var".
  547. * L2 regulariation. Must be a scalar.
  548. * @li global_step: A Tensor of type int32 or int64.
  549. * Training step number. Must be a scalar.
  550. *@par Attributes:
  551. *use_locking: An optional bool. Defaults to "False".
  552. * If "True", updating of the var and accum tensors will be
  553. * protected by a lock; otherwise the behavior is undefined,
  554. * but may exhibit less contention.
  555. *@par Outputs:
  556. *var: A mutable Tensor. Has the same type as "var".
  557. */
  558. REG_OP(ApplyAdagradDA)
  559. .INPUT(var, TensorType::NumberType())
  560. .INPUT(gradient_accumulator, TensorType::NumberType())
  561. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  562. .INPUT(grad, TensorType::NumberType())
  563. .INPUT(lr, TensorType::NumberType())
  564. .INPUT(l1, TensorType::NumberType())
  565. .INPUT(l2, TensorType::NumberType())
  566. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  567. .OUTPUT(var, TensorType::NumberType())
  568. .ATTR(use_locking, Bool, false)
  569. .OP_END_FACTORY_REG(ApplyAdagradDA)
  570. /**
  571. *@brief Returns the dimension index in the destination data format given the one in
  572. * the source data format.
  573. *
  574. *@par Inputs:
  575. * x: A tensor of type int32 or int64.
  576. * A Tensor with each element as a dimension index in source data format.
  577. * Must be in the range [-4, 4).
  578. *
  579. *@par Attributes:
  580. *@li src_format: An optional string. Defaults to NHWC.
  581. * source data format.
  582. *@li dst_format: An optional string. Defaults to NCHW.
  583. * destination data format.
  584. *
  585. *@par Outputs:
  586. * y: A tensor. Has the same type as "x".
  587. *
  588. */
  589. REG_OP(DataFormatDimMap)
  590. .INPUT(x, TensorType::IndexNumberType())
  591. .ATTR(src_format, String, "NHWC")
  592. .ATTR(dst_format, String, "NCHW")
  593. .OUTPUT(y, TensorType::IndexNumberType())
  594. .OP_END_FACTORY_REG(DataFormatDimMap)
  595. /**
  596. * @brief Implements stochastic gradient descent (optionally with momentum).\n
  597. * Nesterov momentum is based on the formula from
  598. * On the importance of initialization and momentum in deep learning.\n
  599. * @par Inputs:
  600. * @li parameters: A mutable tensor of type float16 or float32.\n
  601. * Specifies the iterable of parameters to optimize or dicts defining parameter
  602. * groups.
  603. * @li gradient: A tensor of type float16 or float32.\n
  604. * Specifies the gradient of training step.
  605. * @li learning_rate: A tensor of type float16 or float32.\n
  606. * Specifies the learing_rate of training step.
  607. * @li accum: A tensor of type float16 or float32.
  608. * Specifies the velocity of training step.
  609. * @li momentum: A tensor of type float16 or float32.
  610. * Specifies the momentum factor.
  611. * @li stat: A tensor of type float16 or float32.
  612. * Specifies the status representing the first step or not.
  613. * @par Attributes:
  614. * @li dampening: An optional float, specifying the dampening for momentum.
  615. * Defaults to "0.0".
  616. * @li weight_decay: An optional float, specifying the L2 penalty. Defaults to
  617. * "0.0".
  618. * @li nesterov: An optional bool, specifying whether to enable Nesterov
  619. * momentum. Defaults to "False".
  620. * @par Outputs:
  621. * parameters: A mutable tensor same as input "parameters".
  622. * @see ApplyMomentum()
  623. */
  624. REG_OP(SGD)
  625. .INPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  626. .INPUT(gradient, TensorType(DT_FLOAT, DT_FLOAT16))
  627. .INPUT(learning_rate, TensorType(DT_FLOAT, DT_FLOAT16))
  628. .INPUT(accum, TensorType(DT_FLOAT, DT_FLOAT16))
  629. .INPUT(momentum, TensorType(DT_FLOAT, DT_FLOAT16))
  630. .INPUT(stat, TensorType(DT_FLOAT, DT_FLOAT16))
  631. .OUTPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  632. .ATTR(dampening, Float, 0.0)
  633. .ATTR(weight_decay, Float, 0.0)
  634. .ATTR(nesterov, Bool, false)
  635. .OP_END_FACTORY_REG(SGD)
  636. /**
  637. * @brief Updates "var" according to the RMSProp algorithm.\n
  638. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  639. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  640. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  641. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  642. * var <- var - mom\n
  643. *
  644. * @par Inputs:
  645. * @li var: A mutable tensor. Must be one of the data types defined in\n
  646. * TensorType::NumberType(). Should be from a Variable().
  647. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  648. * Variable().
  649. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  650. * Variable().
  651. * @li lr: A scalar. Must have the same type as "var".
  652. * @li rho: A scalar. Must have the same type as "var".
  653. * @li momentum: A scalar. Must have the same type as "var".
  654. * @li epsilon: A scalar. Must have the same type as "var".
  655. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  656. *
  657. * @par Attributes:
  658. * use_locking: An optional "bool". Defaults to "False". If "True", updating of\n
  659. * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the\n
  660. * behavior is undefined, but may exhibit less contention.
  661. *
  662. * @par Outputs:
  663. * var: A mutable tensor. Has the same type as input "var".
  664. *
  665. * @attention Constraints:
  666. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will \n
  667. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom" \n
  668. * will not update in iterations during which "grad" is 0.
  669. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  670. */
  671. REG_OP(ApplyRMSProp)
  672. .INPUT(var, TensorType::NumberType())
  673. .INPUT(ms, TensorType::NumberType())
  674. .INPUT(mom, TensorType::NumberType())
  675. .INPUT(lr, TensorType::NumberType())
  676. .INPUT(rho, TensorType::NumberType())
  677. .INPUT(momentum, TensorType::NumberType())
  678. .INPUT(epsilon, TensorType::NumberType())
  679. .INPUT(grad, TensorType::NumberType())
  680. .OUTPUT(var, TensorType::NumberType())
  681. .ATTR(use_locking, Bool, false)
  682. .OP_END_FACTORY_REG(ApplyRMSProp)
  683. /**
  684. * @brief Updates "var" according to the RMSProp algorithm, a const input will be
  685. * considered as an attribute.\n
  686. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  687. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  688. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  689. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  690. * var <- var - mom
  691. *
  692. * @par Inputs:
  693. * @li var: A mutable tensor. Must be one of the data types defined in\n
  694. * TensorType::NumberType(). Should be from a Variable().
  695. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  696. * Variable().
  697. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  698. * Variable().
  699. * @li lr: A scalar. Must have the same type as "var".
  700. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  701. *
  702. * @par Attributes:
  703. * @li use_locking: An optional "bool". Defaults to "False". If "True", updating\n
  704. * of the "var", "ms", and "mom" tensors will be protected by a lock; \n
  705. * otherwise the behavior is undefined, but may exhibit less contention.
  706. * @li rho: A required scalar. Must have the same type as "var".
  707. * @li momentum: A required scalar. Must have the same type as "var".
  708. * @li epsilon: A required scalar. Must have the same type as "var".
  709. *
  710. * @par Outputs:
  711. * var: A mutable tensor. Must have the same type as input "var".
  712. *
  713. * @attention Constraints:
  714. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will\n
  715. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"\n
  716. * will not update in iterations during which "grad" is 0.
  717. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  718. */
  719. REG_OP(ApplyRMSPropD)
  720. .INPUT(var, TensorType::NumberType())
  721. .INPUT(ms, TensorType::NumberType())
  722. .INPUT(mom, TensorType::NumberType())
  723. .INPUT(lr, TensorType::NumberType())
  724. .INPUT(grad, TensorType::NumberType())
  725. .OUTPUT(var, TensorType::NumberType())
  726. .OUTPUT(ms, TensorType::NumberType())
  727. .OUTPUT(mom, TensorType::NumberType())
  728. .REQUIRED_ATTR(rho, Float)
  729. .REQUIRED_ATTR(momentum, Float)
  730. .REQUIRED_ATTR(epsilon, Float)
  731. .ATTR(use_locking, Bool, false)
  732. .OP_END_FACTORY_REG(ApplyRMSPropD)
  733. /**
  734. *@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate.
  735. *@par Inputs:
  736. *Six inputs, including:
  737. * @li var: A mutable Tensor of type TensorType::NumberType().
  738. * Should be from a Variable().
  739. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  740. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  741. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  742. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  743. * @li grad: A Tensor of the same type as "var", for the gradient.
  744. *@par Attributes:
  745. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" *tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less *contention.
  746. *@par Outputs:
  747. *var: A mutable Tensor. Has the same type as "var".
  748. */
  749. REG_OP(ApplyProximalAdagrad)
  750. .INPUT(var, TensorType::NumberType())
  751. .INPUT(accum, TensorType::NumberType())
  752. .INPUT(lr, TensorType::NumberType())
  753. .INPUT(l1, TensorType::NumberType())
  754. .INPUT(l2, TensorType::NumberType())
  755. .INPUT(grad, TensorType::NumberType())
  756. .OUTPUT(var, TensorType::NumberType())
  757. .ATTR(use_locking, Bool, false)
  758. .OP_END_FACTORY_REG(ApplyProximalAdagrad)
  759. /**
  760. *@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.\ n
  761. * Compared with op ApplyProximalAdagrad, an additional index tensor is input,
  762. * Only the indices into the first dimensions of "var" and "accum" are updated.
  763. *@par Inputs:
  764. * Seven inputs, including:\n
  765. * @li var: A mutable Tensor.\n
  766. * TensorType::NumberType(). Should be a Variable Tensor.
  767. * @li accum: A mutable Tensor of the same type as "var".\n
  768. * Should be a Variable Tensor.
  769. * @li lr: A Tensor of the same type as "var".\n
  770. * Scaling factor. Must be a scalar.
  771. * @li l1: A Tensor of the same type as "var".\n
  772. * L1 regulariation. Must be a scalar.
  773. * @li l2: A Tensor of the same type as "var".\n
  774. * L2 regulariation. Must be a scalar.
  775. * @li grad: A Tensor. Has the same type as "var". \n
  776. * The gradient.
  777. * @li indices: A vector of indices into the first dimension of "var" and "accum".\n
  778. * TensorType::IndexNumberType().
  779. *@par Attributes:
  780. *use_locking: An optional bool. Defaults to "False".\n
  781. * If "True", updating of the var and accum tensors will be protected by a lock; \n
  782. * If "False", the behavior is undefined, but may exhibit less contention.
  783. *@par Outputs:
  784. *var: A mutable Tensor. Has the same type as "var".
  785. */
  786. REG_OP(SparseApplyProximalAdagrad)
  787. .INPUT(var, TensorType::NumberType())
  788. .INPUT(accum, TensorType::NumberType())
  789. .INPUT(lr, TensorType::NumberType())
  790. .INPUT(l1, TensorType::NumberType())
  791. .INPUT(l2, TensorType::NumberType())
  792. .INPUT(grad, TensorType::NumberType())
  793. .INPUT(indices, TensorType::IndexNumberType())
  794. .OUTPUT(var, TensorType::NumberType())
  795. .ATTR(use_locking, Bool, false)
  796. .OP_END_FACTORY_REG(SparseApplyProximalAdagrad)
  797. /**
  798. *@brief Updates "var" according to the Ftrl-proximal scheme.
  799. *@par Inputs:
  800. *Eight inputs, including:
  801. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  802. * Should be a Variable Tensor.
  803. * @li accum: A mutable Tensor of the same type as "var".
  804. * Should be a Variable Tensor.
  805. * @li linear: A mutable Tensor of the same type as "var".
  806. * Should be a Variable Tensor.
  807. * @li grad: A Tensor of the same type as "var", for the gradient.
  808. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  809. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  810. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  811. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  812. *@par Attributes:
  813. *use_locking: An optional bool. Defaults to "False".
  814. * If "True", updating of the "var" and "accum" tensors will be
  815. * protected by a lock; otherwise the behavior is undefined,
  816. * but may exhibit less contention.
  817. *@par Outputs:
  818. *var: A mutable Tensor. Has the same type as "var".
  819. */
  820. REG_OP(ApplyFtrl)
  821. .INPUT(var, TensorType::NumberType())
  822. .INPUT(accum, TensorType::NumberType())
  823. .INPUT(linear, TensorType::NumberType())
  824. .INPUT(grad, TensorType::NumberType())
  825. .INPUT(lr, TensorType::NumberType())
  826. .INPUT(l1, TensorType::NumberType())
  827. .INPUT(l2, TensorType::NumberType())
  828. .INPUT(lr_power, TensorType::NumberType())
  829. .OUTPUT(var, TensorType::NumberType())
  830. .ATTR(use_locking, Bool, false)
  831. .OP_END_FACTORY_REG(ApplyFtrl)
  832. /**
  833. *@brief Update "var" according to the Ftrl-proximal scheme.
  834. *@par Inputs:
  835. *Nine inputs, including:
  836. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  837. * Should be a Variable Tensor.
  838. * @li accum: A mutable Tensor of the same type as "var".
  839. * Should be a Variable Tensor.
  840. * @li linear: A mutable Tensor of the same type as "var".
  841. * Should be a Variable Tensor.
  842. * @li grad: A Tensor of the same type as "var", for the gradient.
  843. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  844. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  845. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  846. * @li l2_shrinkage: A Tensor of the same type as "var".
  847. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  848. *@par Attributes:
  849. *use_locking: An optional bool. Defaults to "False".
  850. * If "True", updating of the "var" and "accum" tensors will be
  851. * protected by a lock; otherwise the behavior is undefined,
  852. * but may exhibit less contention.
  853. *@par Outputs:
  854. *var: A mutable Tensor. Has the same type as "var".
  855. */
  856. REG_OP(ApplyFtrlV2)
  857. .INPUT(var, TensorType::NumberType())
  858. .INPUT(accum, TensorType::NumberType())
  859. .INPUT(linear, TensorType::NumberType())
  860. .INPUT(grad, TensorType::NumberType())
  861. .INPUT(lr, TensorType::NumberType())
  862. .INPUT(l1, TensorType::NumberType())
  863. .INPUT(l2, TensorType::NumberType())
  864. .INPUT(l2_shrinkage, TensorType::NumberType())
  865. .INPUT(lr_power, TensorType::NumberType())
  866. .OUTPUT(var, TensorType::NumberType())
  867. .ATTR(use_locking, Bool, false)
  868. .OP_END_FACTORY_REG(ApplyFtrlV2)
  869. /**
  870. *@brief Updates "var" according to the Adam algorithm.\n
  871. * lr_t <- text{learning\_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)\n
  872. * m_t <- beta_1 * m_{t-1} + (1 - beta_1) * g\n
  873. * v_t <- max(beta2 * v{t-1}, abs(g))\n
  874. * variable <- variable - lr_t * m_t / (sqrt{v_t} + epsilon)
  875. *
  876. *@attention Constraints:\n
  877. * *The input tensors must have the same shape.*
  878. *
  879. *@par Inputs:
  880. *@li var: A mutable Tensor of the type TensorType::NumberType().
  881. * Should be from a Variable().
  882. *@li m: A mutable Tensor of the same type as "var".
  883. * Should be from a Variable().
  884. *@li v: A mutable Tensor of the same type as "var".
  885. * Should be from a Variable().
  886. *@li beta1_power: A scalar of the same type as "var".
  887. *@li beta2_power: A scalar of the same type as "var".
  888. *@li lr: learning_rate. A scalar of the same type as "var".
  889. *@li beta1: A scalar of the same type as "var".
  890. *@li beta2: A scalar of the same type as "var".
  891. *@li epsilon: A scalar of the same type as "var".
  892. *@li grad: A Tensor of the same type as "var", for the gradient.
  893. *
  894. *@par Attributes:\n
  895. *@li use_locking: An optional bool. Defaults to "False".
  896. * If "True", updating of the "var", m", and "v" tensors will be protected
  897. * by a lock; otherwise the behavior is undefined, but may exhibit less
  898. * contention.
  899. *@li use_nesterov: An optional bool. Defaults to "False".
  900. If "True", uses the nesterov update.
  901. *
  902. *@par Outputs:
  903. * var: A mutable Tensor. Has the same type as intput "var".
  904. */
  905. REG_OP(ApplyAdam)
  906. .INPUT(var, TensorType::NumberType())
  907. .INPUT(m, TensorType::NumberType())
  908. .INPUT(v, TensorType::NumberType())
  909. .INPUT(beta1_power, TensorType::NumberType())
  910. .INPUT(beta2_power, TensorType::NumberType())
  911. .INPUT(lr, TensorType::NumberType())
  912. .INPUT(beta1, TensorType::NumberType())
  913. .INPUT(beta2, TensorType::NumberType())
  914. .INPUT(epsilon, TensorType::NumberType())
  915. .INPUT(grad, TensorType::NumberType())
  916. .OUTPUT(var, TensorType::NumberType())
  917. .OUTPUT(m, TensorType::NumberType())
  918. .OUTPUT(v, TensorType::NumberType())
  919. .ATTR(use_locking, Bool, false)
  920. .ATTR(use_nesterov, Bool, false)
  921. .OP_END_FACTORY_REG(ApplyAdam)
  922. /**
  923. *@brief Updates "var" according to the proximal adadelta scheme.
  924. *@par Inputs:
  925. *Seven inputs, including:
  926. * @li var: A mutable Tensor of type TensorType::NumberType().
  927. * Should be a Variable Tensor.
  928. * @li accum: A mutable Tensor of the same type as "var".
  929. * Should be a Variable Tensor.
  930. * @li accum_update: A mutable Tensor of the same type as "var".
  931. * Should be a Variable Tensor.
  932. * @li lr: A scalar of the same type as "var", for the scaling factor.
  933. * @li rho: A scalar of the same type as "var", for the decay factor.
  934. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  935. * @li grad: A Tensor of the same type as "var", for the gradient.
  936. *@par Attributes:
  937. *use_locking: An optional bool. Defaults to "False".
  938. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  939. * protected by a lock; otherwise the behavior is undefined,
  940. * but may exhibit less contention.
  941. *@par Outputs:
  942. *var: A mutable Tensor. Has the same type as "var".
  943. */
  944. REG_OP(ApplyAdadelta)
  945. .INPUT(var, TensorType::NumberType())
  946. .INPUT(accum, TensorType::NumberType())
  947. .INPUT(accum_update, TensorType::NumberType())
  948. .INPUT(lr, TensorType::NumberType())
  949. .INPUT(rho, TensorType::NumberType())
  950. .INPUT(epsilon, TensorType::NumberType())
  951. .INPUT(grad, TensorType::NumberType())
  952. .OUTPUT(var, TensorType::NumberType())
  953. .ATTR(use_locking, Bool, false)
  954. .OP_END_FACTORY_REG(ApplyAdadelta)
  955. /**
  956. * @brief Updates "var" according to the ApplyMomentum algorithm. \n
  957. * accum = accum * momentum + x1 * x2 \n
  958. * if use_nesterov is True: \n
  959. * var -= x1 * x2 * lr + accum * momentum * lr \n
  960. * else:\n
  961. * var -= accum * lr
  962. *
  963. * @par Inputs:
  964. * Six inputs, including:
  965. * @li var: A mutable Tensor has type TensorType::NumberType().
  966. * Should be a Variable Tensor.
  967. * @li accum: A mutable Tensor has the same type as "var".
  968. * Should be a Variable Tensor.
  969. * @li lr: A scalar has the same type as "var", for the scaling factor.
  970. * @li x1: A Tensor has type TensorType::NumberType().
  971. * @li momentum: A scalar has the same type as "var".
  972. * @li x2: A scalar has the same type as "var".
  973. *
  974. * @par Attributes:
  975. * Two attributes, including:
  976. * @li use_nesterov: An optional bool. Defaults to "False". \n
  977. * If True, the tensor passed to compute grad will be var - lr * momentum * accum, \n
  978. * so in the end, the var you get is actually var - lr * momentum * accum.
  979. * @li use_locking: An optional bool. Defaults to "False". \n
  980. * If "True", updating of the "var", m", and "v" tensors will be protected \n
  981. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  982. *
  983. * @par Outputs:
  984. * Two outputs, including:
  985. * @li var: A mutable Tensor has the same type as "var".
  986. * @li accum: A mutable Tensor has the same type as "var".
  987. */
  988. REG_OP(FusedMulApplyMomentum)
  989. .INPUT(var, TensorType::NumberType())
  990. .INPUT(accum, TensorType::NumberType())
  991. .INPUT(lr, TensorType::NumberType())
  992. .INPUT(x1, TensorType::NumberType())
  993. .INPUT(momentum, TensorType::NumberType())
  994. .INPUT(x2, TensorType::NumberType())
  995. .OUTPUT(var, TensorType::NumberType())
  996. .OUTPUT(accum, TensorType::NumberType())
  997. .ATTR(use_nesterov, Bool, false)
  998. .ATTR(use_locking, Bool, false)
  999. .OP_END_FACTORY_REG(FusedMulApplyMomentum)
  1000. /**
  1001. * @brief Updates "var" according to the ApplyMomentum algorithm. \n
  1002. * accum = accum * momentum + x1 * x2 \n
  1003. * if use_nesterov is True: \n
  1004. * var -= x1 * x2 * lr + accum * momentum * lr \n
  1005. * else: \n
  1006. * var -= accum * lr
  1007. *
  1008. * @par Inputs:
  1009. * Seven inputs, including:
  1010. * @li var: A mutable Tensor of type float32.
  1011. * Should be a Variable Tensor.
  1012. * @li accum: A mutable Tensor has type TensorType::NumberType().
  1013. * Should be a Variable Tensor.
  1014. * @li lr: A scalar has the same type as "accum", for the scaling factor.
  1015. * @li x1: A Tensor has the same type as "accum".
  1016. * @li momentum: A scalar has the same type as "accum".
  1017. * @li x2: A scalar has the same type as "accum".
  1018. * @li var_copy: A Tensor has type float16.
  1019. *
  1020. * @par Attributes:
  1021. * Two Attributes, including:
  1022. * @li use_nesterov: An optional bool. Defaults to "False". \n
  1023. * If True, the tensor passed to compute grad will be var - lr * momentum * accum, \n
  1024. * so in the end, the var you get is actually var - lr * momentum * accum.
  1025. * @li use_locking: An optional bool. Defaults to "False". \n
  1026. * If "True", updating of the "var", m", and "v" tensors will be protected \n
  1027. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1028. *
  1029. * @par Outputs:
  1030. * Three outputs, including:
  1031. * @li var: A Tensor has the type float32.
  1032. * @li var_copy: A Tensor has the type float16.
  1033. * @li accum: A Tensor has the same type as input "accum".
  1034. */
  1035. REG_OP(FusedMulApplyMomentumExtern)
  1036. .INPUT(var, TensorType(DT_FLOAT))
  1037. .INPUT(accum, TensorType::NumberType())
  1038. .INPUT(lr, TensorType::NumberType())
  1039. .INPUT(x1, TensorType::NumberType())
  1040. .INPUT(momentum, TensorType::NumberType())
  1041. .INPUT(x2, TensorType::NumberType())
  1042. .INPUT(var_copy, TensorType(DT_FLOAT16))
  1043. .OUTPUT(var, TensorType(DT_FLOAT))
  1044. .OUTPUT(var_copy, TensorType(DT_FLOAT16))
  1045. .OUTPUT(accum, TensorType::NumberType())
  1046. .ATTR(use_nesterov, Bool, false)
  1047. .ATTR(use_locking, Bool, false)
  1048. .OP_END_FACTORY_REG(FusedMulApplyMomentumExtern)
  1049. /**
  1050. *@brief Update "g" according to the LARS algorithm.
  1051. *@par Inputs:
  1052. *Four inputs, including:
  1053. * @li w: A Tensor. Must be of type TensorType::DT_FLOAT.
  1054. * @li g: A Tensor of the same type and shape as "w".
  1055. * @li weight_decay: A Tensor of the same type as "w", Must be a scalar.
  1056. * @li learning_rate: A Tensor of the same type as "w", Must be a scalar.
  1057. *@par Attributes:
  1058. *Three Attributes, including:
  1059. * @li hyperpara: An optional float. Default value is 0.001.
  1060. * @li epsilon: An optional float. Default value is 1e-5.Avoid denominator is 0.
  1061. * @li use_clip: An optional bool. Defaults to "False".\n
  1062. * If "True", updating learning rate.
  1063. *@par Outputs:
  1064. *g_new: Tensor of the same type as "w".
  1065. */
  1066. REG_OP(LarsV2)
  1067. .INPUT(w, TensorType(DT_FLOAT))
  1068. .INPUT(g, TensorType(DT_FLOAT))
  1069. .INPUT(weight_decay, TensorType(DT_FLOAT))
  1070. .INPUT(learning_rate, TensorType(DT_FLOAT))
  1071. .OUTPUT(g_new, TensorType(DT_FLOAT))
  1072. .ATTR(hyperpara, Float, 0.001)
  1073. .ATTR(epsilon, Float, 0.00001)
  1074. .ATTR(use_clip, Bool, false)
  1075. .OP_END_FACTORY_REG(LarsV2)
  1076. /**
  1077. *@brief Update "g" according to the LARS algorithm.
  1078. *@par Inputs:
  1079. *Six inputs, including:
  1080. * @li w: A Tensor. Must be of type TensorType::DT_FLOAT.
  1081. * @li g: A Tensor of the same type and shape as "w".
  1082. * @li w_square_sum: A Tensor of square_sum(w), has the same type as "w", Must be a scalar.
  1083. * @li g_square_sum: A Tensor of square(g), has the same type as "w", Must be a scalar.
  1084. * @li weight_decay: A Tensor of the same type as "w", Must be a scalar.
  1085. * @li learning_rate: A Tensor of the same type as "w", Must be a scalar.
  1086. *@par Attributes:
  1087. *Three Attributes, including:
  1088. * @li hyperpara: An optional float. Default value is 0.001.
  1089. * @li epsilon: An optional float. Default value is 1e-5.Avoid denominator is 0.
  1090. * @li use_clip: An optional bool. Defaults to "False".\n
  1091. * If "True", updating learning rate.
  1092. *@par Outputs:
  1093. *g_new: Tensor of the same type as "w".
  1094. */
  1095. REG_OP(LarsV2Update)
  1096. .INPUT(w, TensorType(DT_FLOAT))
  1097. .INPUT(g, TensorType(DT_FLOAT))
  1098. .INPUT(w_square_sum, TensorType(DT_FLOAT))
  1099. .INPUT(g_square_sum, TensorType(DT_FLOAT))
  1100. .INPUT(weight_decay, TensorType(DT_FLOAT))
  1101. .INPUT(learning_rate, TensorType(DT_FLOAT))
  1102. .OUTPUT(g_new, TensorType(DT_FLOAT))
  1103. .ATTR(hyperpara, Float, 0.001)
  1104. .ATTR(epsilon, Float, 0.00001)
  1105. .ATTR(use_clip, Bool, false)
  1106. .OP_END_FACTORY_REG(LarsV2Update)
  1107. /**
  1108. * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme.
  1109. * @par Inputs:
  1110. * Nine inputs, including:
  1111. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1112. * Should be a Variable Tensor.
  1113. * @li accum: A mutable Tensor of the same type as "var".
  1114. * Should be a Variable Tensor.
  1115. * @li linear: A mutable Tensor of the same type as "var".
  1116. * Should be a Variable Tensor.
  1117. * @li grad: A Tensor of the same type as "var", for the gradient.
  1118. * @li indices: A vector of indices into the first dimension of var and accum.
  1119. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1120. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  1121. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  1122. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1123. * @par Attributes:
  1124. * use_locking: An optional bool. Defaults to "False".
  1125. * If "True", updating of the "var" and "accum" tensors will be
  1126. * protected by a lock; otherwise the behavior is undefined,
  1127. * but may exhibit less contention.
  1128. * @par Outputs:
  1129. * var: A Tensor. Has the same type and format as input "var".
  1130. */
  1131. REG_OP(SparseApplyFtrl)
  1132. .INPUT(var, TensorType({DT_FLOAT}))
  1133. .INPUT(accum, TensorType({DT_FLOAT}))
  1134. .INPUT(linear, TensorType({DT_FLOAT}))
  1135. .INPUT(grad, TensorType({DT_FLOAT}))
  1136. .INPUT(indices, TensorType({DT_INT32}))
  1137. .INPUT(lr, TensorType({DT_FLOAT}))
  1138. .INPUT(l1, TensorType({DT_FLOAT}))
  1139. .INPUT(l2, TensorType({DT_FLOAT}))
  1140. .INPUT(lr_power, TensorType({DT_FLOAT}))
  1141. .OUTPUT(var, TensorType({DT_FLOAT}))
  1142. .ATTR(use_locking, Bool, false)
  1143. .OP_END_FACTORY_REG(SparseApplyFtrl)
  1144. /**
  1145. * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme.
  1146. * @par Inputs:
  1147. * Nine inputs, including:
  1148. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1149. * Should be a Variable Tensor.
  1150. * @li accum: A mutable Tensor of the same type as "var".
  1151. * Should be a Variable Tensor.
  1152. * @li linear: A mutable Tensor of the same type as "var".
  1153. * Should be a Variable Tensor.
  1154. * @li grad: A Tensor of the same type as "var", for the gradient.
  1155. * @li indices: A vector of indices into the first dimension of var and accum.
  1156. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1157. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  1158. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  1159. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1160. * @par Attributes:
  1161. * use_locking: An optional bool. Defaults to "False".
  1162. * If "True", updating of the "var" and "accum" tensors will be
  1163. * protected by a lock; otherwise the behavior is undefined,
  1164. * but may exhibit less contention.
  1165. * @par Outputs:
  1166. * var: A Tensor. Has the same type and format as input "var".
  1167. * accum: A Tensor. Has the same type and format as input "accum".
  1168. * linear: A Tensor. Has the same type and format as input "linear".
  1169. */
  1170. REG_OP(SparseApplyFtrlD)
  1171. .INPUT(var, TensorType({DT_FLOAT}))
  1172. .INPUT(accum, TensorType({DT_FLOAT}))
  1173. .INPUT(linear, TensorType({DT_FLOAT}))
  1174. .INPUT(grad, TensorType({DT_FLOAT}))
  1175. .INPUT(indices, TensorType({DT_INT32}))
  1176. .OUTPUT(var, TensorType({DT_FLOAT}))
  1177. .OUTPUT(accum, TensorType({DT_FLOAT}))
  1178. .OUTPUT(linear, TensorType({DT_FLOAT}))
  1179. .REQUIRED_ATTR(lr, Float)
  1180. .REQUIRED_ATTR(l1, Float)
  1181. .REQUIRED_ATTR(l2, Float)
  1182. .REQUIRED_ATTR(lr_power, Float)
  1183. .ATTR(use_locking, Bool, false)
  1184. .OP_END_FACTORY_REG(SparseApplyFtrlD)
  1185. /**
  1186. * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme.
  1187. * That is for rows we have grad for, we update var, accum and linear
  1188. * @par Inputs:
  1189. * Ten inputs, including:
  1190. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1191. * Should be a Variable Tensor.
  1192. * @li accum: A mutable Tensor of the same type as "var".
  1193. * Should be a Variable Tensor.
  1194. * @li linear: A mutable Tensor of the same type as "var".
  1195. * Should be a Variable Tensor.
  1196. * @li grad: A Tensor of the same type as "var", for the gradient.
  1197. * @li indices: A vector of indices into the first dimension of var and accum.
  1198. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1199. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  1200. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  1201. * @li l2_shrinkage: A Tensor of the same type as "var", L2 shrinkage regulariation. Must be a scalar.
  1202. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1203. * @par Attributes:
  1204. * use_locking: An optional bool. Defaults to "False".
  1205. * If "True", updating of the "var" and "accum" tensors will be
  1206. * rotected by a lock; otherwise the behavior is undefined,
  1207. * but may exhibit less contention.
  1208. * @par Outputs:
  1209. * var: A Tensor. Has the same type and format as input "var".
  1210. */
  1211. REG_OP(SparseApplyFtrlV2)
  1212. .INPUT(var, TensorType({DT_FLOAT}))
  1213. .INPUT(accum, TensorType({DT_FLOAT}))
  1214. .INPUT(linear, TensorType({DT_FLOAT}))
  1215. .INPUT(grad, TensorType({DT_FLOAT}))
  1216. .INPUT(indices, TensorType({DT_INT32}))
  1217. .INPUT(lr, TensorType({DT_FLOAT}))
  1218. .INPUT(l1, TensorType({DT_FLOAT}))
  1219. .INPUT(l2, TensorType({DT_FLOAT}))
  1220. .INPUT(l2_shrinkage, TensorType({DT_FLOAT}))
  1221. .INPUT(lr_power, TensorType({DT_FLOAT}))
  1222. .OUTPUT(var, TensorType({DT_FLOAT}))
  1223. .ATTR(use_locking, Bool, false)
  1224. .OP_END_FACTORY_REG(SparseApplyFtrlV2)
  1225. /**
  1226. * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme.
  1227. * That is for rows we have grad for, we update var, accum and linear
  1228. * @par Inputs:
  1229. * Ten inputs, including:
  1230. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1231. * Should be a Variable Tensor.
  1232. * @li accum: A mutable Tensor of the same type as "var".
  1233. * Should be a Variable Tensor.
  1234. * @li linear: A mutable Tensor of the same type as "var".
  1235. * Should be a Variable Tensor.
  1236. * @li grad: A Tensor of the same type as "var", for the gradient.
  1237. * @li indices: A vector of indices into the first dimension of var and accum.
  1238. * @par Attributes:
  1239. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1240. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  1241. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  1242. * @li l2_shrinkage: A Tensor of the same type as "var", L2 shrinkage regulariation. Must be a scalar.
  1243. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1244. * @li use_locking: An optional bool. Defaults to "False".
  1245. * If "True", updating of the "var" and "accum" tensors will be
  1246. * rotected by a lock; otherwise the behavior is undefined,
  1247. * but may exhibit less contention.
  1248. * @par Outputs:
  1249. * var: A Tensor. Has the same type and format as input "var".
  1250. * accum: A Tensor. Has the same type and format as input "accum".
  1251. * linear: A Tensor. Has the same type and format as input "linear".
  1252. */
  1253. REG_OP(SparseApplyFtrlV2D)
  1254. .INPUT(var, TensorType({DT_FLOAT}))
  1255. .INPUT(accum, TensorType({DT_FLOAT}))
  1256. .INPUT(linear, TensorType({DT_FLOAT}))
  1257. .INPUT(grad, TensorType({DT_FLOAT}))
  1258. .INPUT(indices, TensorType({DT_INT32}))
  1259. .OUTPUT(var, TensorType({DT_FLOAT}))
  1260. .OUTPUT(accum, TensorType({DT_FLOAT}))
  1261. .OUTPUT(linear, TensorType({DT_FLOAT}))
  1262. .REQUIRED_ATTR(lr, Float)
  1263. .REQUIRED_ATTR(l1, Float)
  1264. .REQUIRED_ATTR(l2, Float)
  1265. .REQUIRED_ATTR(l2_shrinkage, Float)
  1266. .REQUIRED_ATTR(lr_power, Float)
  1267. .ATTR(use_locking, Bool, false)
  1268. .OP_END_FACTORY_REG(SparseApplyFtrlV2D)
  1269. /**
  1270. *@brief Clean memory of workspace list.
  1271. *@par Attributes:
  1272. * @li automic_add_mem_size: sizes of workspaces.
  1273. */
  1274. REG_OP(AtomicAddrClean)
  1275. .ATTR(automic_add_mem_size, ListInt, {})
  1276. .OP_END_FACTORY_REG(AtomicAddrClean)
  1277. } // namespace ge
  1278. #endif // GE_OP_TRAINING_OPS_H

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练或推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：