
nn_training_ops.h

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_OP_TRAINING_OPS_H
#define GE_OP_TRAINING_OPS_H
#include "graph/operator_reg.h"
namespace ge {
/**
*@brief Updates "var" according to the AdaMax algorithm.
* Here, "t-1" denotes the previous iteration.
* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
* v_t <- max(beta2 * v_{t-1}, abs(grad))\n
* var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
* Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li v: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li beta1_power: A scalar. Has the same type as "var".
*@li lr: A scalar for the learning rate. Has the same type as "var".
*@li beta1: A scalar. Has the same type as "var".
*@li beta2: A scalar. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var", "m", and "v" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyAdaMax.
*
*/
REG_OP(ApplyAdaMax)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(beta1, TensorType::NumberType())
    .INPUT(beta2, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdaMax)
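// The snippet below is an illustrative sketch only, not part of the GE operator API:
// it spells out, for a single float element, the AdaMax update described in the
// comment above. The function name and signature are hypothetical; "beta1_power"
// corresponds to beta1^t, which the operator receives as an input tensor.
inline float ApplyAdaMaxSketch(float var, float &m, float &v, float beta1_power,
                               float lr, float beta1, float beta2, float epsilon,
                               float grad) {
  m = beta1 * m + (1.0f - beta1) * grad;              // m_t <- beta1 * m_{t-1} + (1 - beta1) * grad
  const float abs_grad = grad < 0.0f ? -grad : grad;
  v = (beta2 * v > abs_grad) ? beta2 * v : abs_grad;  // v_t <- max(beta2 * v_{t-1}, |grad|)
  return var - lr / (1.0f - beta1_power) * m / (v + epsilon);
}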
/**
*@brief Updates "var" according to the AdaMax algorithm.
* Here, "t-1" denotes the previous iteration.
* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
* v_t <- max(beta2 * v_{t-1}, abs(grad))\n
* var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
* Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li v: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li beta1_power: A scalar. Has the same type as "var".
*@li lr: A scalar for the learning rate. Has the same type as "var".
*@li beta1: A scalar. Has the same type as "var".
*@li beta2: A scalar. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var", "m", and "v" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li m: A mutable tensor. Has the same type as input "m".
*@li v: A mutable tensor. Has the same type as input "v".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyAdaMax.
*
*/
REG_OP(ApplyAdaMaxD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(beta1, TensorType::NumberType())
    .INPUT(beta2, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(m, TensorType::NumberType())
    .OUTPUT(v, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdaMaxD)
/**
*@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
*@par Inputs:
* Five inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor of type float32.
*@li accum: An NCHW, NHWC, or ND Tensor of type float32.
*@li lr: An NCHW, NHWC, or ND Tensor of type float32.
*@li grad: An NCHW, NHWC, or ND Tensor of type float32.
*@li indices: An NCHW, NHWC, or ND Tensor of type int32.
*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyAdagrad.
*/
REG_OP(SparseApplyAdagrad)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(lr, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .ATTR(use_locking, Bool, false)
    .ATTR(update_slots, Bool, true)
    .OP_END_FACTORY_REG(SparseApplyAdagrad)
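// Illustrative sketch only (hypothetical helper, not part of this header's API):
// a minimal, single-column view of how the sparse Adagrad update touches only the
// rows named by "indices". Assumes <vector>, <cstdint>, and <cmath> are available;
// the standard Adagrad rule (accum += grad^2; var -= lr * grad / sqrt(accum)) is
// taken as the per-row update.
inline void SparseApplyAdagradSketch(std::vector<float> &var, std::vector<float> &accum,
                                     float lr, const std::vector<float> &grad,
                                     const std::vector<int32_t> &indices) {
  for (size_t i = 0; i < indices.size(); ++i) {
    const int32_t row = indices[i];                    // only the selected rows are touched
    accum[row] += grad[i] * grad[i];                   // accum += grad * grad
    var[row] -= lr * grad[i] / std::sqrt(accum[row]);  // var -= lr * grad / sqrt(accum)
  }
}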
/**
*@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
*@par Inputs:
* Four inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor of type float32.
*@li accum: An NCHW, NHWC, or ND Tensor of type float32.
*@li grad: An NCHW, NHWC, or ND Tensor of type float32.
*@li indices: An NCHW, NHWC, or ND Tensor of type int32.
*@par Attributes:
*@li lr: A required float, specifying the learning rate (scaling factor).
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
*@par Outputs:
*@li var: A Tensor. Has the same type and format as input "var".
*@li accum: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyAdagrad.
*/
REG_OP(SparseApplyAdagradD)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .OUTPUT(accum, TensorType({DT_FLOAT}))
    .REQUIRED_ATTR(lr, Float)
    .ATTR(use_locking, Bool, false)
    .ATTR(update_slots, Bool, true)
    .OP_END_FACTORY_REG(SparseApplyAdagradD)
/**
*@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
*@par Inputs:
*Six inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor of type float32.
*@li accum: An NCHW, NHWC, or ND Tensor of type float32.
*@li lr: An NCHW, NHWC, or ND Tensor of type float32.
*@li epsilon: An NCHW, NHWC, or ND Tensor of type float32.
*@li grad: An NCHW, NHWC, or ND Tensor of type float32.
*@li indices: An NCHW, NHWC, or ND Tensor of type int32.
*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var".
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator SparseApplyAdagradV2.
*/
REG_OP(SparseApplyAdagradV2)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(lr, TensorType({DT_FLOAT}))
    .INPUT(epsilon, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .ATTR(use_locking, Bool, false)
    .ATTR(update_slots, Bool, true)
    .OP_END_FACTORY_REG(SparseApplyAdagradV2)
/**
*@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
*@par Inputs:
*Four inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor of type float32.
*@li accum: An NCHW, NHWC, or ND Tensor of type float32.
*@li grad: An NCHW, NHWC, or ND Tensor of type float32.
*@li indices: An NCHW, NHWC, or ND Tensor of type int32.
*@par Attributes:
*@li lr: A required float, specifying the learning rate (scaling factor).
*@li epsilon: A required float, a small value added for numerical stability.
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
*@par Outputs:
*@li var: A Tensor. Has the same type and format as input "var".
*@li accum: A Tensor. Has the same type and format as input "accum".
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator SparseApplyAdagradV2.
*/
REG_OP(SparseApplyAdagradV2D)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .OUTPUT(accum, TensorType({DT_FLOAT}))
    .REQUIRED_ATTR(lr, Float)
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .ATTR(update_slots, Bool, true)
    .OP_END_FACTORY_REG(SparseApplyAdagradV2D)
/**
*@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
* want to use Nesterov momentum.
* The computing process is as follows: \n
* accum = accum * momentum + grad\n
* var -= lr * accum
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*@li momentum: A scalar. Has the same type as "var".
*
*@par Attributes:
*@li use_nesterov: An optional bool. Defaults to "False".
* If "True", the tensor passed to compute grad will be
* var - lr * momentum * accum, so in the end, the var you get is actually
* var - lr * momentum * accum.
*
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyMomentum.
*
*/
REG_OP(ApplyMomentum)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_nesterov, Bool, false)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyMomentum)
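// Illustrative sketch only (hypothetical, element-wise form of the update above).
// The plain path follows the comment exactly: accum = accum * momentum + grad,
// then var -= lr * accum. The use_nesterov branch is an assumption, written in the
// common look-ahead formulation; consult the kernel for the exact behavior.
inline float ApplyMomentumSketch(float var, float &accum, float lr, float grad,
                                 float momentum, bool use_nesterov) {
  accum = accum * momentum + grad;
  if (use_nesterov) {
    return var - (grad * lr + accum * momentum * lr);  // assumed look-ahead variant
  }
  return var - lr * accum;
}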
/**
*@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
* want to use Nesterov momentum.
* The computing process is as follows: \n
* accum = accum * momentum + grad\n
* var -= lr * accum
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*@li momentum: A scalar. Has the same type as "var".
*
*@par Attributes:
*@li use_nesterov: An optional bool. Defaults to "False".
* If "True", the tensor passed to compute grad will be
* var - lr * momentum * accum, so in the end, the var you get is actually
* var - lr * momentum * accum.
*
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li accum: A mutable tensor. Has the same type as input "accum".
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyMomentum.
*
*/
REG_OP(ApplyMomentumD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(accum, TensorType::NumberType())
    .ATTR(use_nesterov, Bool, false)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyMomentumD)
/**
*@brief Updates '*var' according to the momentum scheme.
* accum = accum * momentum - grad * lr \n
* if use_nesterov is True: \n
* var += accum * momentum - grad * lr \n
* else: \n
* var += accum
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*@li momentum: A scalar. Has the same type as "var".
*
*@par Attributes:
*@li use_nesterov: An optional bool. Defaults to "False".
* If "True", var will be updated by using Nesterov momentum.
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyKerasMomentum.
*
*/
REG_OP(ApplyKerasMomentum)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .ATTR(use_nesterov, Bool, false)
    .OP_END_FACTORY_REG(ApplyKerasMomentum)
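// Illustrative sketch only (hypothetical, scalar form of the formulas above; not
// part of the GE operator API).
inline float ApplyKerasMomentumSketch(float var, float &accum, float lr, float grad,
                                      float momentum, bool use_nesterov) {
  accum = accum * momentum - grad * lr;          // accum = accum * momentum - grad * lr
  if (use_nesterov) {
    return var + accum * momentum - grad * lr;   // Nesterov look-ahead
  }
  return var + accum;
}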
/**
*@brief Updates '*var' according to the momentum scheme.
* accum = accum * momentum - grad * lr \n
* if use_nesterov is True: \n
* var += accum * momentum - grad * lr \n
* else: \n
* var += accum
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*@li momentum: A scalar. Has the same type as "var". Should be from a
* Variable().
*
*@par Attributes:
*@li use_nesterov: An optional bool. Defaults to "False".
* If "True", var will be updated by using Nesterov momentum.
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li accum: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyKerasMomentum.
*
*/
REG_OP(ApplyKerasMomentumD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(accum, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .ATTR(use_nesterov, Bool, false)
    .OP_END_FACTORY_REG(ApplyKerasMomentumD)
/**
*@brief Updates '*var' according to the Adam algorithm (AMSGrad variant).
* lr_t := learning_rate * sqrt(1 - beta_2^t) / (1 - beta_1^t)
* m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
* v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
* vhat_t := max(vhat_{t-1}, v_t)
* variable := variable - lr_t * m_t / (sqrt(vhat_t) + epsilon)
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li v: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li vhat: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*
*@par Attributes:
*@li beta1: A required float.
*@li beta2: A required float.
*@li epsilon: A required float.
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li m: A mutable tensor. Has the same type as input "var".
*@li v: A mutable tensor. Has the same type as input "var".
*@li vhat: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyAdamWithAmsgrad.
*
*/
REG_OP(ApplyAdamWithAmsgradD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(vhat, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(beta2_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(m, TensorType::NumberType())
    .OUTPUT(v, TensorType::NumberType())
    .OUTPUT(vhat, TensorType::NumberType())
    .REQUIRED_ATTR(beta1, Float)
    .REQUIRED_ATTR(beta2, Float)
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdamWithAmsgradD)
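// Illustrative sketch only (hypothetical, scalar form of the AMSGrad update above;
// assumes <cmath> is available for std::sqrt). "beta1_power" and "beta2_power"
// correspond to beta1^t and beta2^t, which the operator receives as inputs.
inline float ApplyAdamWithAmsgradSketch(float var, float &m, float &v, float &vhat,
                                        float beta1_power, float beta2_power,
                                        float lr, float beta1, float beta2,
                                        float epsilon, float grad) {
  const float lr_t = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  m = beta1 * m + (1.0f - beta1) * grad;
  v = beta2 * v + (1.0f - beta2) * grad * grad;
  vhat = (vhat > v) ? vhat : v;                          // vhat_t <- max(vhat_{t-1}, v_t)
  return var - lr_t * m / (std::sqrt(vhat) + epsilon);
}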
/**
*@brief Updates '*var' according to the Adam algorithm (AMSGrad variant).
* lr_t := learning_rate * sqrt(1 - beta_2^t) / (1 - beta_1^t)
* m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
* v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
* vhat_t := max(vhat_{t-1}, v_t)
* variable := variable - lr_t * m_t / (sqrt(vhat_t) + epsilon)
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li v: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li vhat: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li beta1: A scalar. Has the same type as "var".
*@li beta2: A scalar. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyAdamWithAmsgrad.
*
*/
REG_OP(ApplyAdamWithAmsgrad)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(vhat, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(beta2_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(beta1, TensorType::NumberType())
    .INPUT(beta2, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdamWithAmsgrad)
/**
*@brief Updates "var" according to the PowerSign update.
* Here, "t-1" denotes the previous iteration.
* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
* update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
* var <- var - lr * update
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li logbase: A scalar. Has the same type as "var".
*@li sign_decay: A scalar. Has the same type as "var".
*@li beta: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "m" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyPowerSign.
*
*/
REG_OP(ApplyPowerSign)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(logbase, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyPowerSign)
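// Illustrative sketch only (hypothetical, scalar form of the PowerSign update above;
// assumes <cmath> is available for std::exp). "beta" plays the role of beta1 in the
// formula for m_t.
inline float ApplyPowerSignSketch(float var, float &m, float lr, float logbase,
                                  float sign_decay, float beta, float grad) {
  m = beta * m + (1.0f - beta) * grad;
  const float sign_g = static_cast<float>((grad > 0.0f) - (grad < 0.0f));  // sign(grad)
  const float sign_m = static_cast<float>((m > 0.0f) - (m < 0.0f));        // sign(m_t)
  const float update = std::exp(logbase * sign_decay * sign_g * sign_m) * grad;
  return var - lr * update;
}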
/**
*@brief Updates "var" according to the PowerSign update.
* Here, "t-1" denotes the previous iteration.
* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
* update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
* var <- var - lr * update
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li logbase: A scalar. Has the same type as "var".
*@li sign_decay: A scalar. Has the same type as "var".
*@li beta: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "m" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li m: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyPowerSign.
*
*/
REG_OP(ApplyPowerSignD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(logbase, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(m, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyPowerSignD)
/**
*@brief Updates "var" as the FOBOS algorithm with a fixed learning rate.\n
* prox_v = var - alpha * delta\n
* var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li alpha: A scalar. Has the same type as "var".
*@li l1: A scalar. Has the same type as "var".
*@li l2: A scalar. Has the same type as "var".
*@li delta: A tensor. Has the same type as "var". The change.
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyProximalGradientDescent.
*
*/
REG_OP(ApplyProximalGradientDescent)
    .INPUT(var, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(l1, TensorType::NumberType())
    .INPUT(l2, TensorType::NumberType())
    .INPUT(delta, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyProximalGradientDescent)
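// Illustrative sketch only (hypothetical, scalar form of the FOBOS update above;
// not part of the GE operator API).
inline float ApplyProximalGradientDescentSketch(float var, float alpha, float l1,
                                                float l2, float delta) {
  const float prox_v = var - alpha * delta;                       // prox_v = var - alpha * delta
  const float sign_p = static_cast<float>((prox_v > 0.0f) - (prox_v < 0.0f));
  const float abs_p = prox_v < 0.0f ? -prox_v : prox_v;
  const float shrink = (abs_p - alpha * l1 > 0.0f) ? (abs_p - alpha * l1) : 0.0f;
  return sign_p / (1.0f + alpha * l2) * shrink;                   // soft-thresholded result
}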
/**
*@brief Updates "var" according to the AddSign update.
*@par Inputs:
*Seven inputs, including:
* @li var: A mutable Tensor of type TensorType::NumberType().
* Should be a Variable Tensor.
* @li m: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
* @li alpha: A Tensor of the same type as "var". Must be a scalar.
* @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
* @li beta: A Tensor of the same type as "var". Must be a scalar.
* @li grad: A Tensor of the same type as "var", for the gradient.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "m" tensors will be
* protected by a lock; otherwise the behavior is undefined,
* but may exhibit less contention.
*@par Outputs:
*var: A mutable Tensor. Has the same type as "var".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ApplyAddSign.
*/
REG_OP(ApplyAddSign)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAddSign)
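// Illustrative sketch only (hypothetical, not part of the GE operator API). The doc
// above does not spell out the formula, so this follows the conventional AddSign
// rule: m <- beta * m + (1 - beta) * grad,
// update <- (alpha + sign_decay * sign(grad) * sign(m)) * grad, var <- var - lr * update.
inline float ApplyAddSignSketch(float var, float &m, float lr, float alpha,
                                float sign_decay, float beta, float grad) {
  m = beta * m + (1.0f - beta) * grad;
  const float sign_g = static_cast<float>((grad > 0.0f) - (grad < 0.0f));  // sign(grad)
  const float sign_m = static_cast<float>((m > 0.0f) - (m < 0.0f));        // sign(m)
  const float update = (alpha + sign_decay * sign_g * sign_m) * grad;
  return var - lr * update;
}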
/**
*@brief Updates "var" according to the AddSign update.
*@par Inputs:
*Seven inputs, including:
* @li var: A mutable Tensor of type TensorType::NumberType().
* Should be a Variable Tensor.
* @li m: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
* @li alpha: A Tensor of the same type as "var". Must be a scalar.
* @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
* @li beta: A Tensor of the same type as "var". Must be a scalar.
* @li grad: A Tensor of the same type as "var", for the gradient.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "m" tensors will be
* protected by a lock; otherwise the behavior is undefined,
* but may exhibit less contention.
*@par Outputs:
*@li var: A mutable Tensor. Has the same type as "var".
*@li m: A mutable Tensor. Has the same type as "m".
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ApplyAddSign.
*/
REG_OP(ApplyAddSignD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(m, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAddSignD)
/**
*@brief Updates "var" according to the centered RMSProp algorithm.
* The centered RMSProp algorithm uses an estimate of the centered second moment
* (i.e., the variance) for normalization, as opposed to regular RMSProp, which
* uses the (uncentered) second moment. This often helps with training, but is
* slightly more expensive in terms of computation and memory.
*
* Here, "t-1" denotes the previous iteration.
* mg <- rho * mg_{t-1} + (1-rho) * grad\n
* ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
* var <- var - mom\n
*
*@attention Constraints:
*@li In the dense implementation of this algorithm, mg, ms, and mom will
* update even if the grad is zero, but in this sparse implementation, mg, ms,
* and mom will not update in iterations during which the grad is zero.
*@li The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li mg: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li ms: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li mom: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li rho: A scalar. Has the same type as "var".
*@li momentum: A tensor. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var", "mg", "ms", and "mom" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyCenteredRMSProp.
*
*/
REG_OP(ApplyCenteredRMSProp)
    .INPUT(var, TensorType::NumberType())
    .INPUT(mg, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyCenteredRMSProp)
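// Illustrative sketch only (hypothetical, scalar form of the centered RMSProp update
// above; assumes <cmath> is available for std::sqrt).
inline float ApplyCenteredRMSPropSketch(float var, float &mg, float &ms, float &mom,
                                        float lr, float rho, float momentum,
                                        float epsilon, float grad) {
  mg = rho * mg + (1.0f - rho) * grad;                 // running mean of the gradient
  ms = rho * ms + (1.0f - rho) * grad * grad;          // running mean of the squared gradient
  mom = momentum * mom + lr * grad / std::sqrt(ms - mg * mg + epsilon);
  return var - mom;
}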
/**
*@brief Updates "var" according to the centered RMSProp algorithm.
* The centered RMSProp algorithm uses an estimate of the centered second moment
* (i.e., the variance) for normalization, as opposed to regular RMSProp, which
* uses the (uncentered) second moment. This often helps with training, but is
* slightly more expensive in terms of computation and memory.
*
* Here, "t-1" denotes the previous iteration.
* mg <- rho * mg_{t-1} + (1-rho) * grad\n
* ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
* var <- var - mom\n
*
*@attention Constraints:
*@li In the dense implementation of this algorithm, mg, ms, and mom will
* update even if the grad is zero, but in this sparse implementation, mg, ms,
* and mom will not update in iterations during which the grad is zero.
*@li The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li mg: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li ms: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li mom: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li rho: A scalar. Has the same type as "var".
*@li momentum: A tensor. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var", "mg", "ms", and "mom" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
*@li var: A mutable Tensor. Has the same type as "var".
*@li mg: A mutable Tensor. Has the same type as "mg".
*@li ms: A mutable Tensor. Has the same type as "ms".
*@li mom: A mutable Tensor. Has the same type as "mom".
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyCenteredRMSProp.
*
*/
REG_OP(ApplyCenteredRMSPropD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(mg, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(mg, TensorType::NumberType())
    .OUTPUT(ms, TensorType::NumberType())
    .OUTPUT(mom, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyCenteredRMSPropD)
/**
*@brief Updates "var" by subtracting 'alpha' * 'delta' from it.
* var -= delta * alpha
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li alpha: A scalar. Has the same type as "var".
*@li delta: A tensor for the change. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyGradientDescent.
*
*/
REG_OP(ApplyGradientDescent)
    .INPUT(var, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(delta, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyGradientDescent)
/**
*@brief Updates "var" according to the adagrad scheme.
* accum += grad * grad\n
* var -= lr * grad * (1 / sqrt(accum))
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
*@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyAdagrad.
*
*/
REG_OP(ApplyAdagrad)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(update_slots, Bool, true)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdagrad)
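// Illustrative sketch only (hypothetical, scalar form of the Adagrad update above;
// assumes <cmath> is available for std::sqrt). When update_slots is false, "accum"
// is left unchanged.
inline float ApplyAdagradSketch(float var, float &accum, float lr, float grad,
                                bool update_slots) {
  if (update_slots) {
    accum += grad * grad;                      // accum += grad * grad
  }
  return var - lr * grad / std::sqrt(accum);   // var -= lr * grad * (1 / sqrt(accum))
}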
/**
*@brief Updates "var" according to the adagrad scheme.
* accum += grad * grad\n
* var -= lr * grad * (1 / sqrt(accum))
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
*@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li accum: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyAdagrad.
*
*/
REG_OP(ApplyAdagradD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(accum, TensorType::NumberType())
    .ATTR(update_slots, Bool, true)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdagradD)
/**
* @brief Updates "var" according to the adagradv2 scheme.
* accum += grad * grad \n
* var -= lr * grad * (1 / sqrt(accum) + epsilon)
*
* @par Inputs:
* @li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
* @li accum: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
* @li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
* @li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
* @li epsilon: A scalar. Has the same type as "var".
*
* @par Attributes:
* @li update_slots: An optional bool. Defaults to "True".
* If "True", "accum" will be updated.
* @li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
* @par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
* @attention Constraints:
* The input tensors must have the same shape.
*
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ApplyAdagradV2.
*
*/
REG_OP(ApplyAdagradV2)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(update_slots, Bool, true)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdagradV2)
  1024. /**
  1025. * @brief Updates "var" according to the adagradv2 scheme.
  1026. * accum += grad * grad \n
  1027. * var -= lr * grad * (1 / sqrt(accum) + epsilon)
  1028. *
  1029. * @par Inputs:
  1030. * @li var: A mutable tensor. Must be one of the data types defined in
  1031. * TensorType::NumberType(). Should be from a Variable().
  1032. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  1033. * Variable().
  1034. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  1035. * from a Variable().
  1036. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  1037. * from a Variable().
  1038. *
  1039. * @par Attributes:
1040. * @li epsilon: A required float scalar, added for numerical stability.
  1041. * @li update_slots: An optional bool. Defaults to "True".
  1042. * If "True", "accum" will be updated
  1043. * @li use_locking: An optional bool. Defaults to "False".
  1044. * If "True", updating of the "var" tensor is protected by a lock;
  1045. * otherwise the behavior is undefined, but may exhibit less contention.
  1046. *
  1047. * @par Outputs:
  1048. * var: A mutable tensor. Has the same type as input "var".
  1049. *
  1050. * @attention Constraints:
  1051. * The input tensors must have the same shape.
  1052. *
  1053. * @par Third-party framework compatibility
1054. * Compatible with the TensorFlow operator ApplyAdagradV2.
  1055. *
  1056. */
  1057. REG_OP(ApplyAdagradV2D)
  1058. .INPUT(var, TensorType::NumberType())
  1059. .INPUT(accum, TensorType::NumberType())
  1060. .INPUT(lr, TensorType::NumberType())
  1061. .INPUT(grad, TensorType::NumberType())
  1062. .OUTPUT(var, TensorType::NumberType())
  1063. .OUTPUT(accum, TensorType::NumberType())
  1064. .REQUIRED_ATTR(epsilon, Float)
  1065. .ATTR(update_slots, Bool, true)
  1066. .ATTR(use_locking, Bool, false)
  1067. .OP_END_FACTORY_REG(ApplyAdagradV2D)
  1068. /**
  1069. *@brief Updates "var" according to the proximal adagrad scheme.
  1070. *@par Inputs:
  1071. *Eight inputs, including:
  1072. * @li var: A mutable Tensor. Must be one of the following types:
  1073. * TensorType::NumberType(). Should be a Variable Tensor.
  1074. * @li gradient_accumulator: A mutable Tensor. Must have the same
  1075. * type as "var". Should be a Variable Tensor.
  1076. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  1077. * Should be a Variable Tensor.
  1078. * @li grad: A Tensor of the same type as "var", for the gradient.
  1079. * @li lr: A Tensor of the same type as "var".
  1080. * Scaling factor. Must be a scalar.
  1081. * @li l1: A Tensor of the same type as "var".
1082. * L1 regularization. Must be a scalar.
1083. * @li l2: A Tensor of the same type as "var".
1084. * L2 regularization. Must be a scalar.
  1085. * @li global_step: A Tensor of type int32 or int64.
  1086. * Training step number. Must be a scalar.
  1087. *@par Attributes:
  1088. *use_locking: An optional bool. Defaults to "False".
  1089. * If "True", updating of the var and accum tensors will be
  1090. * protected by a lock; otherwise the behavior is undefined,
  1091. * but may exhibit less contention.
  1092. *@par Outputs:
  1093. *var: A mutable Tensor. Has the same type as "var".
  1094. *@par Third-party framework compatibility
  1095. *Compatible with the TensorFlow operator ApplyAdagradDA.
  1096. */
  1097. REG_OP(ApplyAdagradDA)
  1098. .INPUT(var, TensorType::NumberType())
  1099. .INPUT(gradient_accumulator, TensorType::NumberType())
  1100. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  1101. .INPUT(grad, TensorType::NumberType())
  1102. .INPUT(lr, TensorType::NumberType())
  1103. .INPUT(l1, TensorType::NumberType())
  1104. .INPUT(l2, TensorType::NumberType())
  1105. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  1106. .OUTPUT(var, TensorType::NumberType())
  1107. .ATTR(use_locking, Bool, false)
  1108. .OP_END_FACTORY_REG(ApplyAdagradDA)
  1109. /**
  1110. *@brief Updates "var" according to the proximal adagrad scheme.
  1111. *@par Inputs:
  1112. *Eight inputs, including:
  1113. * @li var: A mutable Tensor. Must be one of the following types:
  1114. * TensorType::NumberType(). Should be a Variable Tensor.
  1115. * @li gradient_accumulator: A mutable Tensor. Must have the same
  1116. * type as "var". Should be a Variable Tensor.
  1117. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  1118. * Should be a Variable Tensor.
  1119. * @li grad: A Tensor of the same type as "var", for the gradient.
  1120. * @li lr: A Tensor of the same type as "var".
  1121. * Scaling factor. Must be a scalar.
  1122. * @li l1: A Tensor of the same type as "var".
1123. * L1 regularization. Must be a scalar.
1124. * @li l2: A Tensor of the same type as "var".
1125. * L2 regularization. Must be a scalar.
  1126. * @li global_step: A Tensor of type int32 or int64.
  1127. * Training step number. Must be a scalar.
  1128. *@par Attributes:
  1129. *use_locking: An optional bool. Defaults to "False".
  1130. * If "True", updating of the var and accum tensors will be
  1131. * protected by a lock; otherwise the behavior is undefined,
  1132. * but may exhibit less contention.
  1133. *@par Outputs:
1134. *@li var: A mutable Tensor. Has the same type as "var".
1135. *@li gradient_accumulator: A mutable Tensor. Has the same type as "var".
1136. *@li gradient_squared_accumulator: A mutable Tensor. Has the same type as "var".
  1137. *@par Third-party framework compatibility
  1138. *Compatible with the TensorFlow operator ApplyAdagradDA.
  1139. */
  1140. REG_OP(ApplyAdagradDAD)
  1141. .INPUT(var, TensorType::NumberType())
  1142. .INPUT(gradient_accumulator, TensorType::NumberType())
  1143. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  1144. .INPUT(grad, TensorType::NumberType())
  1145. .INPUT(lr, TensorType::NumberType())
  1146. .INPUT(l1, TensorType::NumberType())
  1147. .INPUT(l2, TensorType::NumberType())
  1148. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  1149. .OUTPUT(var, TensorType::NumberType())
  1150. .OUTPUT(gradient_accumulator, TensorType::NumberType())
  1151. .OUTPUT(gradient_squared_accumulator, TensorType::NumberType())
  1152. .ATTR(use_locking, Bool, false)
  1153. .OP_END_FACTORY_REG(ApplyAdagradDAD)
  1154. /**
  1155. *@brief Returns the dimension index in the destination data format given the one in
  1156. * the source data format.
  1157. *
  1158. *@par Inputs:
  1159. * x: A tensor of type int32 or int64.
  1160. * A Tensor with each element as a dimension index in source data format.
  1161. * Must be in the range [-4, 4).
  1162. *
  1163. *@par Attributes:
1164. *@li src_format: An optional string. Defaults to "NHWC".
1165. * Source data format. Must be of length 4.
1166. *@li dst_format: An optional string. Defaults to "NCHW".
1167. * Destination data format. Must be of length 4.
  1168. *
  1169. *@par Outputs:
  1170. * y: A tensor. Has the same type as "x". Must be in the range [0, 4).
  1171. *
  1172. *@par Third-party framework compatibility
  1173. *Compatible with the TensorFlow operator DataFormatDimMap.
  1174. *
  1175. */
  1176. REG_OP(DataFormatDimMap)
  1177. .INPUT(x, TensorType::IndexNumberType())
  1178. .ATTR(src_format, String, "NHWC")
  1179. .ATTR(dst_format, String, "NCHW")
  1180. .OUTPUT(y, TensorType::IndexNumberType())
  1181. .OP_END_FACTORY_REG(DataFormatDimMap)
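/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * sketch of the dimension-index mapping performed by DataFormatDimMap for
 * 4-character formats; the helper name DataFormatDimMapReference is hypothetical.
 *
 *   #include <string>
 *
 *   inline int DataFormatDimMapReference(int x,
 *                                        const std::string& src_format = "NHWC",
 *                                        const std::string& dst_format = "NCHW") {
 *     const int rank = static_cast<int>(src_format.size());     // 4
 *     const int i = (x % rank + rank) % rank;                    // wrap [-4, 4) into [0, 4)
 *     return static_cast<int>(dst_format.find(src_format[i]));  // same axis in dst_format
 *   }
 *
 *   // Example: with the defaults, index 1 ("H" in NHWC) maps to 2 ("H" in NCHW).
 */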
  1182. /**
  1183. * @brief Implements stochastic gradient descent (optionally with momentum).
  1184. * Nesterov momentum is based on the formula from
  1185. * On the importance of initialization and momentum in deep learning.\n
  1186. * @par Inputs:
  1187. * @li parameters: A mutable tensor of type float16 or float32.\n
  1188. * Specifies the iterable of parameters to optimize or dicts defining parameter
  1189. * groups.
  1190. * @li gradient: A tensor of type float16 or float32.\n
  1191. * Specifies the gradient of training step.
  1192. * @li learning_rate: A tensor of type float16 or float32.\n
1193. * Specifies the learning rate of the training step.
  1194. * @li accum: A tensor of type float16 or float32.
  1195. * Specifies the velocity of training step.
  1196. * @li momentum: A tensor of type float16 or float32.
  1197. * Specifies the momentum factor.
1198. * @li stat: A tensor of type float16 or float32.
1199. * Specifies a flag indicating whether this is the first training step.
  1200. * @par Attributes:
  1201. * @li dampening: An optional float, specifying the dampening for momentum.
  1202. * Defaults to "0.0".
  1203. * @li weight_decay: An optional float, specifying the L2 penalty. Defaults to
  1204. * "0.0".
  1205. * @li nesterov: An optional bool, specifying whether to enable Nesterov
  1206. * momentum. Defaults to "False".
  1207. * @par Outputs:
  1208. * parameters: A mutable tensor same as input "parameters".
  1209. * @see ApplyMomentum()
  1210. * @par Third-party framework compatibility
  1211. * @li Compatible with the PyTorch operator SGD.
  1212. */
  1213. REG_OP(SGD)
  1214. .INPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  1215. .INPUT(gradient, TensorType(DT_FLOAT, DT_FLOAT16))
  1216. .INPUT(learning_rate, TensorType(DT_FLOAT, DT_FLOAT16))
  1217. .INPUT(accum, TensorType(DT_FLOAT, DT_FLOAT16))
  1218. .INPUT(momentum, TensorType(DT_FLOAT, DT_FLOAT16))
  1219. .INPUT(stat, TensorType(DT_FLOAT, DT_FLOAT16))
  1220. .OUTPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  1221. .ATTR(dampening, Float, 0.0)
  1222. .ATTR(weight_decay, Float, 0.0)
  1223. .ATTR(nesterov, Bool, false)
  1224. .OP_END_FACTORY_REG(SGD)
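/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * sketch of the PyTorch-style momentum SGD step this operator mirrors, assuming
 * "stat" being non-zero marks the first step (so "accum" is initialized from the
 * gradient); the helper name SgdReference is hypothetical.
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   inline void SgdReference(std::vector<float>& parameters,
 *                            const std::vector<float>& gradient,
 *                            std::vector<float>& accum,          // velocity
 *                            float learning_rate, float momentum, bool first_step,
 *                            float dampening = 0.0f, float weight_decay = 0.0f,
 *                            bool nesterov = false) {
 *     for (std::size_t i = 0; i < parameters.size(); ++i) {
 *       const float d = gradient[i] + weight_decay * parameters[i];   // L2 penalty
 *       accum[i] = first_step ? d : momentum * accum[i] + (1.0f - dampening) * d;
 *       const float step = nesterov ? d + momentum * accum[i] : accum[i];
 *       parameters[i] -= learning_rate * step;
 *     }
 *   }
 */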
  1225. /**
  1226. * @brief Updates "var" according to the RMSProp algorithm.
  1227. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  1228. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  1229. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  1230. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  1231. * var <- var - mom\n
  1232. *
  1233. * @par Inputs:
  1234. * @li var: A mutable tensor. Must be one of the data types defined in
  1235. * TensorType::NumberType(). Should be from a Variable().
  1236. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  1237. * Variable().
  1238. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  1239. * Variable().
  1240. * @li lr: A scalar. Must have the same type as "var".
  1241. * @li rho: A scalar. Must have the same type as "var".
  1242. * @li momentum: A scalar. Must have the same type as "var".
  1243. * @li epsilon: A scalar. Must have the same type as "var".
  1244. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  1245. *
  1246. * @par Attributes:
  1247. * use_locking: An optional "bool". Defaults to "False". If "True", updating of
  1248. * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the
  1249. * behavior is undefined, but may exhibit less contention.
  1250. *
  1251. * @par Outputs:
  1252. * var: A mutable tensor. Has the same type as input "var".
  1253. *
  1254. * @attention Constraints:
  1255. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will
  1256. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"
  1257. * will not update in iterations during which "grad" is 0.
  1258. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  1259. *
  1260. * @par Third-party framework compatibility
  1261. * @li Compatible with the TensorFlow operator ApplyRMSProp.
  1262. */
  1263. REG_OP(ApplyRMSProp)
  1264. .INPUT(var, TensorType::NumberType())
  1265. .INPUT(ms, TensorType::NumberType())
  1266. .INPUT(mom, TensorType::NumberType())
  1267. .INPUT(lr, TensorType::NumberType())
  1268. .INPUT(rho, TensorType::NumberType())
  1269. .INPUT(momentum, TensorType::NumberType())
  1270. .INPUT(epsilon, TensorType::NumberType())
  1271. .INPUT(grad, TensorType::NumberType())
  1272. .OUTPUT(var, TensorType::NumberType())
  1273. .ATTR(use_locking, Bool, false)
  1274. .OP_END_FACTORY_REG(ApplyRMSProp)
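/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * dense sketch of the RMSProp update described above; the helper name
 * ApplyRMSPropReference is hypothetical.
 *
 *   #include <cmath>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   inline void ApplyRMSPropReference(std::vector<float>& var,
 *                                     std::vector<float>& ms,
 *                                     std::vector<float>& mom,
 *                                     const std::vector<float>& grad,
 *                                     float lr, float rho, float momentum,
 *                                     float epsilon) {
 *     for (std::size_t i = 0; i < var.size(); ++i) {
 *       ms[i] = rho * ms[i] + (1.0f - rho) * grad[i] * grad[i];
 *       mom[i] = momentum * mom[i] + lr * grad[i] / std::sqrt(ms[i] + epsilon);
 *       var[i] -= mom[i];
 *     }
 *   }
 */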
  1275. /**
  1276. * @brief Updates "var" according to the RMSProp algorithm, a const input will be
  1277. * considered as an attribute.
  1278. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  1279. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  1280. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  1281. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  1282. * var <- var - mom
  1283. *
  1284. * @par Inputs:
  1285. * @li var: A mutable tensor. Must be one of the data types defined in
  1286. * TensorType::NumberType(). Should be from a Variable().
  1287. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  1288. * Variable().
  1289. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  1290. * Variable().
  1291. * @li lr: A scalar. Must have the same type as "var".
  1292. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  1293. *
  1294. * @par Attributes:
  1295. * @li use_locking: An optional "bool". Defaults to "False". If "True", updating
  1296. * of the "var", "ms", and "mom" tensors will be protected by a lock;
  1297. * otherwise the behavior is undefined, but may exhibit less contention.
1298. * @li rho: A required float scalar, for the decay rate.
1299. * @li momentum: A required float scalar.
1300. * @li epsilon: A required float scalar, added for numerical stability.
  1301. *
  1302. * @par Outputs:
  1303. * var: A mutable tensor. Must have the same type as input "var".
  1304. *
  1305. * @attention Constraints:
  1306. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will
  1307. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"
  1308. * will not update in iterations during which "grad" is 0.
  1309. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  1310. *
  1311. * @par Third-party framework compatibility
  1312. * @li Compatible with the TensorFlow operator ApplyRMSProp.
  1313. */
  1314. REG_OP(ApplyRMSPropD)
  1315. .INPUT(var, TensorType::NumberType())
  1316. .INPUT(ms, TensorType::NumberType())
  1317. .INPUT(mom, TensorType::NumberType())
  1318. .INPUT(lr, TensorType::NumberType())
  1319. .INPUT(grad, TensorType::NumberType())
  1320. .OUTPUT(var, TensorType::NumberType())
  1321. .OUTPUT(ms, TensorType::NumberType())
  1322. .OUTPUT(mom, TensorType::NumberType())
  1323. .REQUIRED_ATTR(rho, Float)
  1324. .REQUIRED_ATTR(momentum, Float)
  1325. .REQUIRED_ATTR(epsilon, Float)
  1326. .ATTR(use_locking, Bool, false)
  1327. .OP_END_FACTORY_REG(ApplyRMSPropD)
  1328. /**
  1329. *@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate.
  1330. *@par Inputs:
  1331. *Six inputs, including:
  1332. * @li var: A mutable Tensor of type TensorType::NumberType().
  1333. * Should be from a Variable().
  1334. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  1335. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1336. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1337. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1338. * @li grad: A Tensor of the same type as "var", for the gradient.
  1339. *@par Attributes:
1340. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1341. *@par Outputs:
  1342. *var: A mutable tensor. Must have the same type as input "var".
  1343. *@par Third-party framework compatibility
  1344. *Compatible with the TensorFlow operator ApplyProximalAdagrad.
  1345. */
  1346. REG_OP(ApplyProximalAdagrad)
  1347. .INPUT(var, TensorType::NumberType())
  1348. .INPUT(accum, TensorType::NumberType())
  1349. .INPUT(lr, TensorType::NumberType())
  1350. .INPUT(l1, TensorType::NumberType())
  1351. .INPUT(l2, TensorType::NumberType())
  1352. .INPUT(grad, TensorType::NumberType())
  1353. .OUTPUT(var, TensorType::NumberType())
  1354. .ATTR(use_locking, Bool, false)
  1355. .OP_END_FACTORY_REG(ApplyProximalAdagrad)
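/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * dense sketch of the FOBOS/proximal-Adagrad step, assuming the TensorFlow
 * ApplyProximalAdagrad semantics (the exact formula is not spelled out in the
 * comment above); the helper name ApplyProximalAdagradReference is hypothetical.
 *
 *   #include <algorithm>
 *   #include <cmath>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   inline void ApplyProximalAdagradReference(std::vector<float>& var,
 *                                             std::vector<float>& accum,
 *                                             const std::vector<float>& grad,
 *                                             float lr, float l1, float l2) {
 *     for (std::size_t i = 0; i < var.size(); ++i) {
 *       accum[i] += grad[i] * grad[i];
 *       const float rate = lr / std::sqrt(accum[i]);              // Adagrad learning rate
 *       const float prox = var[i] - rate * grad[i];               // plain gradient step
 *       const float shrink = std::max(std::fabs(prox) - rate * l1, 0.0f);
 *       var[i] = (prox < 0.0f ? -shrink : shrink) / (1.0f + rate * l2);
 *     }
 *   }
 */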
  1356. /**
  1357. *@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate.
  1358. *@par Inputs:
  1359. *Six inputs, including:
  1360. * @li var: A mutable Tensor of type TensorType::NumberType().
  1361. * Should be from a Variable().
  1362. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  1363. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1364. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1365. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1366. * @li grad: A Tensor of the same type as "var", for the gradient.
  1367. *@par Attributes:
1368. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1369. *@par Outputs:
  1370. * @li var: A mutable Tensor. Has the same type as "var".
  1371. * @li accum: A mutable Tensor. Has the same type as "var".
  1372. *@par Third-party framework compatibility
1373. *Compatible with the TensorFlow operator ApplyProximalAdagrad.
  1374. */
  1375. REG_OP(ApplyProximalAdagradD)
  1376. .INPUT(var, TensorType::NumberType())
  1377. .INPUT(accum, TensorType::NumberType())
  1378. .INPUT(lr, TensorType::NumberType())
  1379. .INPUT(l1, TensorType::NumberType())
  1380. .INPUT(l2, TensorType::NumberType())
  1381. .INPUT(grad, TensorType::NumberType())
  1382. .OUTPUT(var, TensorType::NumberType())
  1383. .OUTPUT(accum, TensorType::NumberType())
  1384. .ATTR(use_locking, Bool, false)
  1385. .OP_END_FACTORY_REG(ApplyProximalAdagradD)
  1386. /**
  1387. *@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.
1388. * Compared with op ApplyProximalAdagrad, an additional index tensor is input.
1389. * Only the indices into the first dimensions of "var" and "accum" are updated.
  1390. *@par Inputs:
  1391. * Seven inputs, including:\n
  1392. * @li var: A mutable Tensor.\n
  1393. * TensorType::NumberType(). Should be a Variable Tensor.
  1394. * @li accum: A mutable Tensor of the same type as "var".\n
  1395. * Should be a Variable Tensor. Should be greater than or equal to zero.\n
  1396. * Accum and grad cannot be equal to zero at the same time.
  1397. * @li lr: A Tensor of the same type as "var".\n
  1398. * Scaling factor. Must be a scalar. Should be greater than zero.
  1399. * @li l1: A Tensor of the same type as "var".\n
1400. * L1 regularization. Must be a scalar. Should be greater than or equal to zero.
1401. * @li l2: A Tensor of the same type as "var".\n
1402. * L2 regularization. Must be a scalar. Should be greater than or equal to zero.
  1403. * @li grad: A Tensor. Has the same type as "var".\n
  1404. * The gradient.
  1405. * @li indices: A vector of indices into the first dimension of "var" and "accum".\n
  1406. * TensorType::IndexNumberType(). Can contain duplicate values.
  1407. *@par Attributes:
  1408. *use_locking: An optional bool. Defaults to "False".\n
  1409. * If "True", updating of the var and accum tensors will be protected by a lock; \n
  1410. * If "False", the behavior is undefined, but may exhibit less contention.
  1411. *@par Outputs:
  1412. *var: A mutable Tensor. Has the same type as "var".
  1413. *@par Third-party framework compatibility
  1414. *Compatible with the TensorFlow operator SparseApplyProximalAdagrad.
  1415. */
  1416. REG_OP(SparseApplyProximalAdagrad)
  1417. .INPUT(var, TensorType::NumberType())
  1418. .INPUT(accum, TensorType::NumberType())
  1419. .INPUT(lr, TensorType::NumberType())
  1420. .INPUT(l1, TensorType::NumberType())
  1421. .INPUT(l2, TensorType::NumberType())
  1422. .INPUT(grad, TensorType::NumberType())
  1423. .INPUT(indices, TensorType::IndexNumberType())
  1424. .OUTPUT(var, TensorType::NumberType())
  1425. .ATTR(use_locking, Bool, false)
  1426. .OP_END_FACTORY_REG(SparseApplyProximalAdagrad)
  1427. /**
1428. *@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.\n
1429. * Compared with op ApplyProximalAdagrad, an additional index tensor is input.
1430. * Only the indices into the first dimensions of "var" and "accum" are updated.
  1431. *@par Inputs:
  1432. * Seven inputs, including:\n
  1433. * @li var: A mutable Tensor.\n
  1434. * TensorType::NumberType(). Should be a Variable Tensor.
  1435. * @li accum: A mutable Tensor of the same type as "var".\n
  1436. * Should be a Variable Tensor. Should be greater than or equal to zero.\n
  1437. * Accum and grad cannot be equal to zero at the same time.
  1438. * @li lr: A Tensor of the same type as "var".\n
  1439. * Scaling factor. Must be a scalar. Should be greater than zero.
  1440. * @li l1: A Tensor of the same type as "var".\n
1441. * L1 regularization. Must be a scalar. Should be greater than or equal to zero.
1442. * @li l2: A Tensor of the same type as "var".\n
1443. * L2 regularization. Must be a scalar. Should be greater than or equal to zero.
  1444. * @li grad: A Tensor. Has the same type as "var". \n
  1445. * The gradient.
  1446. * @li indices: A vector of indices into the first dimension of "var" and "accum".\n
  1447. * TensorType::IndexNumberType(). Can contain duplicate values.
  1448. *@par Attributes:
  1449. *use_locking: An optional bool. Defaults to "False".\n
  1450. * If "True", updating of the var and accum tensors will be protected by a lock; \n
  1451. * If "False", the behavior is undefined, but may exhibit less contention.
  1452. *@par Outputs:
  1453. *@li var: A mutable Tensor. Has the same type as "var".
  1454. *@li accum: A mutable Tensor. Has the same type as "var".
  1455. *@par Third-party framework compatibility
  1456. *Compatible with the TensorFlow operator SparseApplyProximalAdagrad.
  1457. */
  1458. REG_OP(SparseApplyProximalAdagradD)
  1459. .INPUT(var, TensorType::NumberType())
  1460. .INPUT(accum, TensorType::NumberType())
  1461. .INPUT(lr, TensorType::NumberType())
  1462. .INPUT(l1, TensorType::NumberType())
  1463. .INPUT(l2, TensorType::NumberType())
  1464. .INPUT(grad, TensorType::NumberType())
  1465. .INPUT(indices, TensorType::IndexNumberType())
  1466. .OUTPUT(var, TensorType::NumberType())
  1467. .OUTPUT(accum, TensorType::NumberType())
  1468. .ATTR(use_locking, Bool, false)
  1469. .OP_END_FACTORY_REG(SparseApplyProximalAdagradD)
  1470. /**
  1471. *@brief Updates "var" according to the Ftrl-proximal scheme.
  1472. *@par Inputs:
  1473. *Eight inputs, including:
  1474. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1475. * Should be a Variable Tensor.
  1476. * @li accum: A mutable Tensor of the same type as "var".
  1477. * Should be a Variable Tensor.
  1478. * @li linear: A mutable Tensor of the same type as "var".
  1479. * Should be a Variable Tensor.
  1480. * @li grad: A Tensor of the same type as "var", for the gradient.
  1481. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1482. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1483. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1484. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1485. *@par Attributes:
  1486. *use_locking: An optional bool. Defaults to "False".
  1487. * If "True", updating of the "var" and "accum" tensors will be
  1488. * protected by a lock; otherwise the behavior is undefined,
  1489. * but may exhibit less contention.
  1490. *@par Outputs:
  1491. *var: A mutable Tensor. Has the same type as "var".
  1492. *@par Third-party framework compatibility
  1493. *Compatible with the TensorFlow operator ApplyFtrl.
  1494. */
  1495. REG_OP(ApplyFtrl)
  1496. .INPUT(var, TensorType::NumberType())
  1497. .INPUT(accum, TensorType::NumberType())
  1498. .INPUT(linear, TensorType::NumberType())
  1499. .INPUT(grad, TensorType::NumberType())
  1500. .INPUT(lr, TensorType::NumberType())
  1501. .INPUT(l1, TensorType::NumberType())
  1502. .INPUT(l2, TensorType::NumberType())
  1503. .INPUT(lr_power, TensorType::NumberType())
  1504. .OUTPUT(var, TensorType::NumberType())
  1505. .ATTR(use_locking, Bool, false)
  1506. .OP_END_FACTORY_REG(ApplyFtrl)
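/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * dense sketch of one FTRL-proximal step, assuming the TensorFlow ApplyFtrl
 * semantics (the exact formula is not spelled out in the comment above); the
 * helper name ApplyFtrlReference is hypothetical.
 *
 *   #include <cmath>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   inline void ApplyFtrlReference(std::vector<float>& var,
 *                                  std::vector<float>& accum,
 *                                  std::vector<float>& linear,
 *                                  const std::vector<float>& grad,
 *                                  float lr, float l1, float l2, float lr_power) {
 *     for (std::size_t i = 0; i < var.size(); ++i) {
 *       const float accum_new = accum[i] + grad[i] * grad[i];
 *       linear[i] += grad[i] -
 *           (std::pow(accum_new, -lr_power) - std::pow(accum[i], -lr_power)) / lr * var[i];
 *       const float quadratic = std::pow(accum_new, -lr_power) / lr + 2.0f * l2;
 *       const float sign = linear[i] < 0.0f ? -1.0f : 1.0f;
 *       var[i] = std::fabs(linear[i]) > l1 ? (sign * l1 - linear[i]) / quadratic : 0.0f;
 *       accum[i] = accum_new;
 *     }
 *   }
 */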
  1507. /**
  1508. *@brief Updates "var" according to the Ftrl-proximal scheme.
  1509. *@par Inputs:
  1510. *Eight inputs, including:
  1511. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1512. * Should be a Variable Tensor.
  1513. * @li accum: A mutable Tensor of the same type as "var".
  1514. * Should be a Variable Tensor.
  1515. * @li linear: A mutable Tensor of the same type as "var".
  1516. * Should be a Variable Tensor.
  1517. * @li grad: A Tensor of the same type as "var", for the gradient.
  1518. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1519. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1520. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1521. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1522. *@par Attributes:
  1523. *use_locking: An optional bool. Defaults to "False".
  1524. * If "True", updating of the "var" and "accum" tensors will be
  1525. * protected by a lock; otherwise the behavior is undefined,
  1526. * but may exhibit less contention.
  1527. *@par Outputs:
  1528. *@li var: A mutable Tensor. Has the same type as "var".
  1529. *@li accum: A mutable Tensor. Has the same type as "accum".
  1530. *@li linear: A mutable Tensor. Has the same type as "linear".
  1531. *@par Third-party framework compatibility
  1532. *Compatible with the TensorFlow operator ApplyFtrl.
  1533. */
  1534. REG_OP(ApplyFtrlD)
  1535. .INPUT(var, TensorType::NumberType())
  1536. .INPUT(accum, TensorType::NumberType())
  1537. .INPUT(linear, TensorType::NumberType())
  1538. .INPUT(grad, TensorType::NumberType())
  1539. .INPUT(lr, TensorType::NumberType())
  1540. .INPUT(l1, TensorType::NumberType())
  1541. .INPUT(l2, TensorType::NumberType())
  1542. .INPUT(lr_power, TensorType::NumberType())
  1543. .OUTPUT(var, TensorType::NumberType())
  1544. .OUTPUT(accum, TensorType::NumberType())
  1545. .OUTPUT(linear, TensorType::NumberType())
  1546. .ATTR(use_locking, Bool, false)
  1547. .OP_END_FACTORY_REG(ApplyFtrlD)
  1548. /**
  1549. *@brief Update "var" according to the Ftrl-proximal scheme.
  1550. *@par Inputs:
  1551. *Nine inputs, including:
  1552. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1553. * Should be a Variable Tensor.
  1554. * @li accum: A mutable Tensor of the same type as "var".
  1555. * Should be a Variable Tensor.
  1556. * @li linear: A mutable Tensor of the same type as "var".
  1557. * Should be a Variable Tensor.
  1558. * @li grad: A Tensor of the same type as "var", for the gradient.
  1559. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1560. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1561. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1562. * @li l2_shrinkage: A Tensor of the same type as "var".
  1563. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1564. *@par Attributes:
  1565. *use_locking: An optional bool. Defaults to "False".
  1566. * If "True", updating of the "var" and "accum" tensors will be
  1567. * protected by a lock; otherwise the behavior is undefined,
  1568. * but may exhibit less contention.
  1569. *@par Outputs:
  1570. *var: A mutable Tensor. Has the same type as "var".
  1571. *@par Third-party framework compatibility
  1572. *Compatible with the TensorFlow operator ApplyFtrlV2.
  1573. */
  1574. REG_OP(ApplyFtrlV2)
  1575. .INPUT(var, TensorType::NumberType())
  1576. .INPUT(accum, TensorType::NumberType())
  1577. .INPUT(linear, TensorType::NumberType())
  1578. .INPUT(grad, TensorType::NumberType())
  1579. .INPUT(lr, TensorType::NumberType())
  1580. .INPUT(l1, TensorType::NumberType())
  1581. .INPUT(l2, TensorType::NumberType())
  1582. .INPUT(l2_shrinkage, TensorType::NumberType())
  1583. .INPUT(lr_power, TensorType::NumberType())
  1584. .OUTPUT(var, TensorType::NumberType())
  1585. .ATTR(use_locking, Bool, false)
  1586. .OP_END_FACTORY_REG(ApplyFtrlV2)
  1587. /**
  1588. *@brief Update "var" according to the Ftrl-proximal scheme.
  1589. *@par Inputs:
  1590. *Nine inputs, including:
  1591. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1592. * Should be a Variable Tensor.
  1593. * @li accum: A mutable Tensor of the same type as "var".
  1594. * Should be a Variable Tensor.
  1595. * @li linear: A mutable Tensor of the same type as "var".
  1596. * Should be a Variable Tensor.
  1597. * @li grad: A Tensor of the same type as "var", for the gradient.
  1598. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1599. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1600. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1601. * @li l2_shrinkage: A Tensor of the same type as "var".
  1602. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1603. *@par Attributes:
  1604. *use_locking: An optional bool. Defaults to "False".
  1605. * If "True", updating of the "var" and "accum" tensors will be
  1606. * protected by a lock; otherwise the behavior is undefined,
  1607. * but may exhibit less contention.
  1608. *@par Outputs:
1609. *@li var: A mutable Tensor. Has the same type as "var".
1610. *@li accum: A mutable Tensor. Has the same type as "accum".
1611. *@li linear: A mutable Tensor. Has the same type as "linear".
  1612. *@par Third-party framework compatibility
  1613. *Compatible with the TensorFlow operator ApplyFtrlV2.
  1614. */
  1615. REG_OP(ApplyFtrlV2D)
  1616. .INPUT(var, TensorType::NumberType())
  1617. .INPUT(accum, TensorType::NumberType())
  1618. .INPUT(linear, TensorType::NumberType())
  1619. .INPUT(grad, TensorType::NumberType())
  1620. .INPUT(lr, TensorType::NumberType())
  1621. .INPUT(l1, TensorType::NumberType())
  1622. .INPUT(l2, TensorType::NumberType())
  1623. .INPUT(l2_shrinkage, TensorType::NumberType())
  1624. .INPUT(lr_power, TensorType::NumberType())
  1625. .OUTPUT(var, TensorType::NumberType())
  1626. .OUTPUT(accum, TensorType::NumberType())
  1627. .OUTPUT(linear, TensorType::NumberType())
  1628. .ATTR(use_locking, Bool, false)
  1629. .OP_END_FACTORY_REG(ApplyFtrlV2D)
  1630. /**
  1631. *@brief Updates "var" according to the Adam algorithm.
1632. * lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)\n
1633. * m_t <- beta1 * m_{t-1} + (1 - beta1) * g\n
1634. * v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g\n
1635. * var <- var - lr_t * m_t / (sqrt(v_t) + epsilon)
  1636. *
  1637. *@attention Constraints:
  1638. * *The input tensors must have the same shape.*
  1639. *
  1640. *@par Inputs:
  1641. *@li var: A mutable Tensor of the type TensorType::NumberType().
  1642. * Should be from a Variable().
  1643. *@li m: A mutable Tensor of the same type as "var".
  1644. * Should be from a Variable().
  1645. *@li v: A mutable Tensor of the same type as "var".
  1646. * Should be from a Variable().
  1647. *@li beta1_power: A scalar of the same type as "var".
  1648. *@li beta2_power: A scalar of the same type as "var".
  1649. *@li lr: learning_rate. A scalar of the same type as "var".
  1650. *@li beta1: A scalar of the same type as "var".
  1651. *@li beta2: A scalar of the same type as "var".
  1652. *@li epsilon: A scalar of the same type as "var".
  1653. *@li grad: A Tensor of the same type as "var", for the gradient.
  1654. *
  1655. *@par Attributes:
  1656. *@li use_locking: An optional bool. Defaults to "False".
  1657. * If "True", updating of the "var", m", and "v" tensors will be protected
  1658. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1659. * contention.
1660. *@li use_nesterov: An optional bool. Defaults to "False".
1661. * If "True", uses the Nesterov update.
  1662. *
  1663. *@par Outputs:
1664. * var: A mutable Tensor. Has the same type as input "var".
  1665. *@par Third-party framework compatibility
  1666. *Compatible with the TensorFlow operator ApplyAdam.
  1667. */
  1668. REG_OP(ApplyAdam)
  1669. .INPUT(var, TensorType::NumberType())
  1670. .INPUT(m, TensorType::NumberType())
  1671. .INPUT(v, TensorType::NumberType())
  1672. .INPUT(beta1_power, TensorType::NumberType())
  1673. .INPUT(beta2_power, TensorType::NumberType())
  1674. .INPUT(lr, TensorType::NumberType())
  1675. .INPUT(beta1, TensorType::NumberType())
  1676. .INPUT(beta2, TensorType::NumberType())
  1677. .INPUT(epsilon, TensorType::NumberType())
  1678. .INPUT(grad, TensorType::NumberType())
  1679. .OUTPUT(var, TensorType::NumberType())
  1680. .ATTR(use_locking, Bool, false)
  1681. .ATTR(use_nesterov, Bool, false)
  1682. .OP_END_FACTORY_REG(ApplyAdam)
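/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * dense sketch of the Adam step described above (without the Nesterov variant);
 * the helper name ApplyAdamReference is hypothetical.
 *
 *   #include <cmath>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   inline void ApplyAdamReference(std::vector<float>& var,
 *                                  std::vector<float>& m,
 *                                  std::vector<float>& v,
 *                                  const std::vector<float>& grad,
 *                                  float beta1_power, float beta2_power,
 *                                  float lr, float beta1, float beta2, float epsilon) {
 *     const float lr_t = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
 *     for (std::size_t i = 0; i < var.size(); ++i) {
 *       m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
 *       v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
 *       var[i] -= lr_t * m[i] / (std::sqrt(v[i]) + epsilon);
 *     }
 *   }
 */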
  1683. /**
  1684. *@brief Updates "var" according to the Adam algorithm.
1685. * lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)\n
1686. * m_t <- beta1 * m_{t-1} + (1 - beta1) * g\n
1687. * v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g\n
1688. * var <- var - lr_t * m_t / (sqrt(v_t) + epsilon)
  1689. *
  1690. *@attention Constraints:
  1691. * *The input tensors must have the same shape.*
  1692. *
  1693. *@par Inputs:
  1694. *@li var: A mutable Tensor of the type TensorType::NumberType().
  1695. * Should be from a Variable().
  1696. *@li m: A mutable Tensor of the same type as "var".
  1697. * Should be from a Variable().
  1698. *@li v: A mutable Tensor of the same type as "var".
  1699. * Should be from a Variable().
  1700. *@li beta1_power: A scalar of the same type as "var".
  1701. *@li beta2_power: A scalar of the same type as "var".
  1702. *@li lr: learning_rate. A scalar of the same type as "var".
  1703. *@li beta1: A scalar of the same type as "var".
  1704. *@li beta2: A scalar of the same type as "var".
  1705. *@li epsilon: A scalar of the same type as "var".
  1706. *@li grad: A Tensor of the same type as "var", for the gradient.
  1707. *
  1708. *@par Attributes:
  1709. *@li use_locking: An optional bool. Defaults to "False".
  1710. * If "True", updating of the "var", m", and "v" tensors will be protected
  1711. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1712. * contention.
1713. *@li use_nesterov: An optional bool. Defaults to "False".
1714. * If "True", uses the Nesterov update.
  1715. *
  1716. *@par Outputs:
  1717. *@li var: A mutable tensor. Has the same type as input "var".
  1718. *@li m: A mutable tensor. Has the same type as input "m".
  1719. *@li v: A mutable tensor. Has the same type as input "v".
  1720. *@par Third-party framework compatibility
  1721. *Compatible with the TensorFlow operator ApplyAdam.
  1722. */
  1723. REG_OP(ApplyAdamD)
  1724. .INPUT(var, TensorType::NumberType())
  1725. .INPUT(m, TensorType::NumberType())
  1726. .INPUT(v, TensorType::NumberType())
  1727. .INPUT(beta1_power, TensorType::NumberType())
  1728. .INPUT(beta2_power, TensorType::NumberType())
  1729. .INPUT(lr, TensorType::NumberType())
  1730. .INPUT(beta1, TensorType::NumberType())
  1731. .INPUT(beta2, TensorType::NumberType())
  1732. .INPUT(epsilon, TensorType::NumberType())
  1733. .INPUT(grad, TensorType::NumberType())
  1734. .OUTPUT(var, TensorType::NumberType())
  1735. .OUTPUT(m, TensorType::NumberType())
  1736. .OUTPUT(v, TensorType::NumberType())
  1737. .ATTR(use_locking, Bool, false)
  1738. .ATTR(use_nesterov, Bool, false)
  1739. .OP_END_FACTORY_REG(ApplyAdamD)
  1740. /**
  1741. *@brief Updates "var" according to the proximal adadelta scheme.
  1742. *@par Inputs:
  1743. *Seven inputs, including:
  1744. * @li var: A mutable Tensor of type TensorType::NumberType().
  1745. * Should be a Variable Tensor.
  1746. * @li accum: A mutable Tensor of the same type as "var".
  1747. * Should be a Variable Tensor.
  1748. * @li accum_update: A mutable Tensor of the same type as "var".
  1749. * Should be a Variable Tensor.
  1750. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1751. * @li rho: A scalar of the same type as "var", for the decay factor.
  1752. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  1753. * @li grad: A Tensor of the same type as "var", for the gradient.
  1754. *@par Attributes:
  1755. *use_locking: An optional bool. Defaults to "False".
  1756. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  1757. * protected by a lock; otherwise the behavior is undefined,
  1758. * but may exhibit less contention.
  1759. *@par Outputs:
  1760. *var: A mutable Tensor. Has the same type as "var".
  1761. *@par Third-party framework compatibility
  1762. * Compatible with the TensorFlow operator ApplyAdadelta.
  1763. */
  1764. REG_OP(ApplyAdadelta)
  1765. .INPUT(var, TensorType::NumberType())
  1766. .INPUT(accum, TensorType::NumberType())
  1767. .INPUT(accum_update, TensorType::NumberType())
  1768. .INPUT(lr, TensorType::NumberType())
  1769. .INPUT(rho, TensorType::NumberType())
  1770. .INPUT(epsilon, TensorType::NumberType())
  1771. .INPUT(grad, TensorType::NumberType())
  1772. .OUTPUT(var, TensorType::NumberType())
  1773. .ATTR(use_locking, Bool, false)
  1774. .OP_END_FACTORY_REG(ApplyAdadelta)
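/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * dense sketch of one Adadelta step, assuming the TensorFlow ApplyAdadelta
 * semantics (the exact formula is not spelled out in the comment above); the
 * helper name ApplyAdadeltaReference is hypothetical.
 *
 *   #include <cmath>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   inline void ApplyAdadeltaReference(std::vector<float>& var,
 *                                      std::vector<float>& accum,
 *                                      std::vector<float>& accum_update,
 *                                      const std::vector<float>& grad,
 *                                      float lr, float rho, float epsilon) {
 *     for (std::size_t i = 0; i < var.size(); ++i) {
 *       accum[i] = rho * accum[i] + (1.0f - rho) * grad[i] * grad[i];
 *       const float update =
 *           std::sqrt(accum_update[i] + epsilon) / std::sqrt(accum[i] + epsilon) * grad[i];
 *       accum_update[i] = rho * accum_update[i] + (1.0f - rho) * update * update;
 *       var[i] -= lr * update;
 *     }
 *   }
 */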
  1775. /**
  1776. *@brief Updates "var" according to the proximal adadelta scheme.
  1777. *@par Inputs:
  1778. *Seven inputs, including:
  1779. * @li var: A mutable Tensor of type TensorType::NumberType().
  1780. * Should be a Variable Tensor.
  1781. * @li accum: A mutable Tensor of the same type as "var".
  1782. * Should be a Variable Tensor.
  1783. * @li accum_update: A mutable Tensor of the same type as "var".
  1784. * Should be a Variable Tensor.
  1785. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1786. * @li rho: A scalar of the same type as "var", for the decay factor.
  1787. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  1788. * @li grad: A Tensor of the same type as "var", for the gradient.
  1789. *@par Attributes:
  1790. *use_locking: An optional bool. Defaults to "False".
  1791. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  1792. * protected by a lock; otherwise the behavior is undefined,
  1793. * but may exhibit less contention.
  1794. *@par Outputs:
  1795. *@li var: A mutable Tensor. Has the same type as "var".
  1796. *@li accum: A mutable Tensor. Has the same type as "var".
  1797. *@li accum_update: A mutable Tensor. Has the same type as "var".
  1798. *@par Third-party framework compatibility
  1799. * Compatible with the TensorFlow operator ApplyAdadelta.
  1800. */
  1801. REG_OP(ApplyAdadeltaD)
  1802. .INPUT(var, TensorType::NumberType())
  1803. .INPUT(accum, TensorType::NumberType())
  1804. .INPUT(accum_update, TensorType::NumberType())
  1805. .INPUT(lr, TensorType::NumberType())
  1806. .INPUT(rho, TensorType::NumberType())
  1807. .INPUT(epsilon, TensorType::NumberType())
  1808. .INPUT(grad, TensorType::NumberType())
  1809. .OUTPUT(var, TensorType::NumberType())
  1810. .OUTPUT(accum, TensorType::NumberType())
  1811. .OUTPUT(accum_update, TensorType::NumberType())
  1812. .ATTR(use_locking, Bool, false)
  1813. .OP_END_FACTORY_REG(ApplyAdadeltaD)
  1814. /**
  1815. * @brief Updates "var" according to the ApplyMomentum algorithm. \n
  1816. * accum = accum * momentum + x1 * x2 \n
  1817. * if use_nesterov is True: \n
  1818. * var -= x1 * x2 * lr + accum * momentum * lr \n
  1819. * else:\n
  1820. * var -= accum * lr
  1821. *
  1822. * @par Inputs:
  1823. * Six inputs, including:
  1824. * @li var: A mutable Tensor has type TensorType::NumberType().
  1825. * Should be a Variable Tensor.
  1826. * @li accum: A mutable Tensor has the same type as "var".
  1827. * Should be a Variable Tensor.
  1828. * @li lr: A scalar has the same type as "var", for the scaling factor.
  1829. * @li x1: A Tensor has type TensorType::NumberType().
  1830. * @li momentum: A scalar has the same type as "var".
  1831. * @li x2: A scalar has the same type as "var".
  1832. *
  1833. * @par Attributes:
  1834. * Two attributes, including:
  1835. * @li use_nesterov: An optional bool. Defaults to "False". \n
  1836. * If True, the tensor passed to compute grad will be var - lr * momentum * accum, \n
  1837. * so in the end, the var you get is actually var - lr * momentum * accum.
  1838. * @li use_locking: An optional bool. Defaults to "False". \n
  1839. * If "True", updating of the "var", m", and "v" tensors will be protected \n
  1840. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1841. *
  1842. * @par Outputs:
  1843. * Two outputs, including:
  1844. * @li var: A mutable Tensor has the same type as "var".
  1845. * @li accum: A mutable Tensor has the same type as "var".
  1846. */
  1847. REG_OP(FusedMulApplyMomentum)
  1848. .INPUT(var, TensorType::NumberType())
  1849. .INPUT(accum, TensorType::NumberType())
  1850. .INPUT(lr, TensorType::NumberType())
  1851. .INPUT(x1, TensorType::NumberType())
  1852. .INPUT(momentum, TensorType::NumberType())
  1853. .INPUT(x2, TensorType::NumberType())
  1854. .OUTPUT(var, TensorType::NumberType())
  1855. .OUTPUT(accum, TensorType::NumberType())
  1856. .ATTR(use_nesterov, Bool, false)
  1857. .ATTR(use_locking, Bool, false)
  1858. .OP_END_FACTORY_REG(FusedMulApplyMomentum)
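/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * dense sketch of the fused multiply + momentum step described above, where the
 * gradient is formed as x1 * x2 (x2 taken as a scalar); the helper name
 * FusedMulApplyMomentumReference is hypothetical.
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   inline void FusedMulApplyMomentumReference(std::vector<float>& var,
 *                                              std::vector<float>& accum,
 *                                              const std::vector<float>& x1,
 *                                              float x2, float lr, float momentum,
 *                                              bool use_nesterov = false) {
 *     for (std::size_t i = 0; i < var.size(); ++i) {
 *       const float g = x1[i] * x2;                      // fused multiply
 *       accum[i] = accum[i] * momentum + g;
 *       if (use_nesterov)
 *         var[i] -= g * lr + accum[i] * momentum * lr;
 *       else
 *         var[i] -= accum[i] * lr;
 *     }
 *   }
 */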
  1859. /**
  1860. * @brief Updates "var" according to the ApplyMomentum algorithm. \n
  1861. * accum = accum * momentum + x1 * x2 \n
  1862. * if use_nesterov is True: \n
  1863. * var -= x1 * x2 * lr + accum * momentum * lr \n
  1864. * else: \n
  1865. * var -= accum * lr
  1866. *
  1867. * @par Inputs:
  1868. * Seven inputs, including:
  1869. * @li var: A mutable Tensor of type float32.
  1870. * Should be a Variable Tensor.
  1871. * @li accum: A mutable Tensor has type TensorType::NumberType().
  1872. * Should be a Variable Tensor.
  1873. * @li lr: A scalar has the same type as "accum", for the scaling factor.
  1874. * @li x1: A Tensor has the same type as "accum".
  1875. * @li momentum: A scalar has the same type as "accum".
  1876. * @li x2: A scalar has the same type as "accum".
  1877. * @li var_copy: A Tensor has type float16.
  1878. *
  1879. * @par Attributes:
  1880. * Two Attributes, including:
  1881. * @li use_nesterov: An optional bool. Defaults to "False". \n
  1882. * If True, the tensor passed to compute grad will be var - lr * momentum * accum, \n
  1883. * so in the end, the var you get is actually var - lr * momentum * accum.
  1884. * @li use_locking: An optional bool. Defaults to "False". \n
  1885. * If "True", updating of the "var", m", and "v" tensors will be protected \n
  1886. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1887. *
  1888. * @par Outputs:
  1889. * Three outputs, including:
  1890. * @li var: A Tensor has the type float32.
  1891. * @li var_copy: A Tensor has the type float16.
  1892. * @li accum: A Tensor has the same type as input "accum".
  1893. */
  1894. REG_OP(FusedMulApplyMomentumExtern)
  1895. .INPUT(var, TensorType(DT_FLOAT))
  1896. .INPUT(accum, TensorType::NumberType())
  1897. .INPUT(lr, TensorType::NumberType())
  1898. .INPUT(x1, TensorType::NumberType())
  1899. .INPUT(momentum, TensorType::NumberType())
  1900. .INPUT(x2, TensorType::NumberType())
  1901. .INPUT(var_copy, TensorType(DT_FLOAT16))
  1902. .OUTPUT(var, TensorType(DT_FLOAT))
  1903. .OUTPUT(var_copy, TensorType(DT_FLOAT16))
  1904. .OUTPUT(accum, TensorType::NumberType())
  1905. .ATTR(use_nesterov, Bool, false)
  1906. .ATTR(use_locking, Bool, false)
  1907. .OP_END_FACTORY_REG(FusedMulApplyMomentumExtern)
  1908. /**
  1909. *@brief Update "g" according to the LARS algorithm.
  1910. *@par Inputs:
  1911. *Four inputs, including:
  1912. * @li w: A Tensor. Must be of type TensorType::DT_FLOAT.
  1913. * @li g: A Tensor of the same type and shape as "w".
1914. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
1915. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar.
  1916. *@par Attributes:
  1917. *Three Attributes, including:
1918. * @li hyperpara: An optional float. Defaults to 0.001.
1919. * @li epsilon: An optional float. Defaults to 1e-5. Used to avoid a zero denominator.
1920. * @li use_clip: An optional bool. Defaults to "False".\n
1921. * If "True", the learning rate is clipped when the update is applied.
  1922. *@par Outputs:
  1923. *g_new: Tensor of the same type as "w".
  1924. */
  1925. REG_OP(LarsV2)
  1926. .INPUT(w, TensorType(DT_FLOAT))
  1927. .INPUT(g, TensorType(DT_FLOAT))
  1928. .INPUT(weight_decay, TensorType(DT_FLOAT))
  1929. .INPUT(learning_rate, TensorType(DT_FLOAT))
  1930. .OUTPUT(g_new, TensorType(DT_FLOAT))
  1931. .ATTR(hyperpara, Float, 0.001)
  1932. .ATTR(epsilon, Float, 0.00001)
  1933. .ATTR(use_clip, Bool, false)
  1934. .OP_END_FACTORY_REG(LarsV2)
  1935. /**
  1936. *@brief Update "g" according to the LARS algorithm.
  1937. *@par Inputs:
  1938. *Six inputs, including:
  1939. * @li w: A Tensor. Must be of type TensorType::DT_FLOAT.
  1940. * @li g: A Tensor of the same type and shape as "w".
1941. * @li w_square_sum: A Tensor holding square_sum(w). Has the same type as "w". Must be a scalar.
1942. * @li g_square_sum: A Tensor holding square_sum(g). Has the same type as "w". Must be a scalar.
1943. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
1944. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar.
  1945. *@par Attributes:
  1946. *Three Attributes, including:
1947. * @li hyperpara: An optional float. Defaults to 0.001.
1948. * @li epsilon: An optional float. Defaults to 1e-5. Used to avoid a zero denominator.
1949. * @li use_clip: An optional bool. Defaults to "False".\n
1950. * If "True", the learning rate is clipped when the update is applied.
  1951. *@par Outputs:
  1952. *g_new: Tensor of the same type as "w".
  1953. */
  1954. REG_OP(LarsV2Update)
  1955. .INPUT(w, TensorType(DT_FLOAT))
  1956. .INPUT(g, TensorType(DT_FLOAT))
  1957. .INPUT(w_square_sum, TensorType(DT_FLOAT))
  1958. .INPUT(g_square_sum, TensorType(DT_FLOAT))
  1959. .INPUT(weight_decay, TensorType(DT_FLOAT))
  1960. .INPUT(learning_rate, TensorType(DT_FLOAT))
  1961. .OUTPUT(g_new, TensorType(DT_FLOAT))
  1962. .ATTR(hyperpara, Float, 0.001)
  1963. .ATTR(epsilon, Float, 0.00001)
  1964. .ATTR(use_clip, Bool, false)
  1965. .OP_END_FACTORY_REG(LarsV2Update)
  1966. /**
  1967. * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme.
  1968. * @par Inputs:
  1969. * Nine inputs, including:
  1970. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1971. * Should be a Variable Tensor.
  1972. * @li accum: A mutable Tensor of the same type as "var".
  1973. * Should be a Variable Tensor. The value of accum must be greater than 0.
  1974. * @li linear: A mutable Tensor of the same type as "var".
  1975. * Should be a Variable Tensor.
  1976. * @li grad: A Tensor of the same type as "var", for the gradient.
  1977. * @li indices: A vector of indices into the first dimension of var and accum.
  1978. * The value of indices must be unique. Otherwise, the result is unpredictable.
  1979. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1980. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1981. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1982. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1983. * @par Attributes:
  1984. * use_locking: An optional bool. Defaults to "False".
  1985. * If "True", updating of the "var" and "accum" tensors will be
  1986. * protected by a lock; otherwise the behavior is undefined,
  1987. * but may exhibit less contention.
  1988. * @par Outputs:
  1989. * var: A Tensor. Has the same type and format as input "var".
  1990. * @par Third-party framework compatibility
  1991. * Compatible with the TensorFlow operator SparseApplyFtrl.
  1992. */
  1993. REG_OP(SparseApplyFtrl)
  1994. .INPUT(var, TensorType({DT_FLOAT}))
  1995. .INPUT(accum, TensorType({DT_FLOAT}))
  1996. .INPUT(linear, TensorType({DT_FLOAT}))
  1997. .INPUT(grad, TensorType({DT_FLOAT}))
  1998. .INPUT(indices, TensorType({DT_INT32}))
  1999. .INPUT(lr, TensorType({DT_FLOAT}))
  2000. .INPUT(l1, TensorType({DT_FLOAT}))
  2001. .INPUT(l2, TensorType({DT_FLOAT}))
  2002. .INPUT(lr_power, TensorType({DT_FLOAT}))
  2003. .OUTPUT(var, TensorType({DT_FLOAT}))
  2004. .ATTR(use_locking, Bool, false)
  2005. .OP_END_FACTORY_REG(SparseApplyFtrl)
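/*
 * Editorial note: illustrative only, not part of the registrations. A minimal
 * sketch of the sparse access pattern: only the rows of "var", "accum" and
 * "linear" selected by "indices" are touched, each with the same per-element
 * FTRL-proximal step as in the dense ApplyFtrl sketch above. The helper name
 * and the row-major layout are assumptions.
 *
 *   #include <cmath>
 *   #include <cstddef>
 *   #include <cstdint>
 *   #include <vector>
 *
 *   inline void SparseApplyFtrlReference(std::vector<float>& var,      // [rows * cols]
 *                                        std::vector<float>& accum,    // [rows * cols]
 *                                        std::vector<float>& linear,   // [rows * cols]
 *                                        const std::vector<float>& grad,      // [n * cols]
 *                                        const std::vector<int32_t>& indices, // [n]
 *                                        std::size_t cols,
 *                                        float lr, float l1, float l2, float lr_power) {
 *     for (std::size_t k = 0; k < indices.size(); ++k) {
 *       const std::size_t row = static_cast<std::size_t>(indices[k]) * cols;
 *       for (std::size_t j = 0; j < cols; ++j) {
 *         const std::size_t i = row + j;
 *         const float g = grad[k * cols + j];
 *         const float accum_new = accum[i] + g * g;
 *         linear[i] += g - (std::pow(accum_new, -lr_power) -
 *                           std::pow(accum[i], -lr_power)) / lr * var[i];
 *         const float quadratic = std::pow(accum_new, -lr_power) / lr + 2.0f * l2;
 *         const float sign = linear[i] < 0.0f ? -1.0f : 1.0f;
 *         var[i] = std::fabs(linear[i]) > l1 ? (sign * l1 - linear[i]) / quadratic : 0.0f;
 *         accum[i] = accum_new;
 *       }
 *     }
 *   }
 */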
  2006. /**
  2007. * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme.
  2008. * @par Inputs:
  2009. * Five inputs, including:
  2010. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  2011. * Should be a Variable Tensor.
  2012. * @li accum: A mutable Tensor of the same type as "var".
  2013. * Should be a Variable Tensor. The value of accum must be greater than 0.
  2014. * @li linear: A mutable Tensor of the same type as "var".
  2015. * Should be a Variable Tensor.
  2016. * @li grad: A Tensor of the same type as "var", for the gradient.
  2017. * @li indices: A vector of indices into the first dimension of var and accum.
  2018. * The value of indices must be unique. Otherwise, the result is unpredictable.
  2019. * @par Attributes:
2020. * @li lr: A required float, for the scaling factor.
2021. * @li l1: A required float, for L1 regularization.
2022. * @li l2: A required float, for L2 regularization.
2023. * @li lr_power: A required float, for the scaling factor.
  2024. * @li use_locking: An optional bool. Defaults to "False".
  2025. * If "True", updating of the "var" and "accum" tensors will be
  2026. * protected by a lock; otherwise the behavior is undefined,
  2027. * but may exhibit less contention.
  2028. * @par Outputs:
  2029. * @li var: A Tensor. Has the same type and format as input "var".
  2030. * @li accum: A Tensor. Has the same type and format as input "accum".
  2031. * @li linear: A Tensor. Has the same type and format as input "linear".
  2032. * @par Third-party framework compatibility
  2033. * Compatible with the TensorFlow operator SparseApplyFtrl.
  2034. */
  2035. REG_OP(SparseApplyFtrlD)
  2036. .INPUT(var, TensorType({DT_FLOAT}))
  2037. .INPUT(accum, TensorType({DT_FLOAT}))
  2038. .INPUT(linear, TensorType({DT_FLOAT}))
  2039. .INPUT(grad, TensorType({DT_FLOAT}))
  2040. .INPUT(indices, TensorType({DT_INT32}))
  2041. .OUTPUT(var, TensorType({DT_FLOAT}))
  2042. .OUTPUT(accum, TensorType({DT_FLOAT}))
  2043. .OUTPUT(linear, TensorType({DT_FLOAT}))
  2044. .REQUIRED_ATTR(lr, Float)
  2045. .REQUIRED_ATTR(l1, Float)
  2046. .REQUIRED_ATTR(l2, Float)
  2047. .REQUIRED_ATTR(lr_power, Float)
  2048. .ATTR(use_locking, Bool, false)
  2049. .OP_END_FACTORY_REG(SparseApplyFtrlD)
  2050. /**
  2051. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
2052. * That is, for the rows for which "grad" is provided, "var", "accum" and "linear" are updated.
  2053. * @par Inputs:
  2054. * Ten inputs, including:
  2055. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  2056. * Should be a Variable Tensor.
  2057. * @li accum: A mutable Tensor of the same type as "var".
  2058. * Should be a Variable Tensor.
  2059. * @li linear: A mutable Tensor of the same type as "var".
  2060. * Should be a Variable Tensor.
  2061. * @li grad: A Tensor of the same type as "var", for the gradient.
  2062. * @li indices: A vector of indices into the first dimension of "var" and "accum".
  2063. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
2064. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
2065. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
2066. * @li l2_shrinkage: A Tensor of the same type as "var", for L2 shrinkage regularization. Must be a scalar.
  2067. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  2068. * @par Attributes:
  2069. * use_locking: An optional bool. Defaults to "False".
  2070. * If "True", updating of the "var" and "accum" tensors will be
  2071. * protected by a lock; otherwise the behavior is undefined,
  2072. * but may exhibit less contention.
  2073. * @par Outputs:
  2074. * var: A Tensor. Has the same type and format as input "var".
  2075. * @par Third-party framework compatibility
  2076. * Compatible with the TensorFlow operator SparseApplyFtrlV2.
  2077. */
  2078. REG_OP(SparseApplyFtrlV2)
  2079. .INPUT(var, TensorType({DT_FLOAT}))
  2080. .INPUT(accum, TensorType({DT_FLOAT}))
  2081. .INPUT(linear, TensorType({DT_FLOAT}))
  2082. .INPUT(grad, TensorType({DT_FLOAT}))
  2083. .INPUT(indices, TensorType({DT_INT32}))
  2084. .INPUT(lr, TensorType({DT_FLOAT}))
  2085. .INPUT(l1, TensorType({DT_FLOAT}))
  2086. .INPUT(l2, TensorType({DT_FLOAT}))
  2087. .INPUT(l2_shrinkage, TensorType({DT_FLOAT}))
  2088. .INPUT(lr_power, TensorType({DT_FLOAT}))
  2089. .OUTPUT(var, TensorType({DT_FLOAT}))
  2090. .ATTR(use_locking, Bool, false)
  2091. .OP_END_FACTORY_REG(SparseApplyFtrlV2)
  2092. /**
  2093. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
2094. * That is, for the rows for which "grad" is provided, "var", "accum" and "linear" are updated.
  2095. * @par Inputs:
  2096. * Five inputs, including:
  2097. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  2098. * Should be a Variable Tensor.
  2099. * @li accum: A mutable Tensor of the same type as "var".
  2100. * Should be a Variable Tensor.
  2101. * @li linear: A mutable Tensor of the same type as "var".
  2102. * Should be a Variable Tensor.
  2103. * @li grad: A Tensor of the same type as "var", for the gradient.
  2104. * @li indices: A vector of indices into the first dimension of "var" and "accum".
  2105. * @par Attributes:
2106. * @li lr: A required float, for the scaling factor.
2107. * @li l1: A required float, for L1 regularization.
2108. * @li l2: A required float, for L2 regularization.
2109. * @li l2_shrinkage: A required float, for L2 shrinkage regularization.
2110. * @li lr_power: A required float, for the scaling factor.
  2111. * @li use_locking: An optional bool. Defaults to "False".
  2112. * If "True", updating of the "var" and "accum" tensors will be
  2113. * protected by a lock; otherwise the behavior is undefined,
  2114. * but may exhibit less contention.
  2115. * @par Outputs:
  2116. * @li var: A Tensor. Has the same type and format as input "var".
  2117. * @li accum: A Tensor. Has the same type and format as input "accum".
  2118. * @li linear: A Tensor. Has the same type and format as input "linear".
  2119. * @par Third-party framework compatibility
2120. * Compatible with the TensorFlow operator SparseApplyFtrlV2.
  2121. */
  2122. REG_OP(SparseApplyFtrlV2D)
  2123. .INPUT(var, TensorType({DT_FLOAT}))
  2124. .INPUT(accum, TensorType({DT_FLOAT}))
  2125. .INPUT(linear, TensorType({DT_FLOAT}))
  2126. .INPUT(grad, TensorType({DT_FLOAT}))
  2127. .INPUT(indices, TensorType({DT_INT32}))
  2128. .OUTPUT(var, TensorType({DT_FLOAT}))
  2129. .OUTPUT(accum, TensorType({DT_FLOAT}))
  2130. .OUTPUT(linear, TensorType({DT_FLOAT}))
  2131. .REQUIRED_ATTR(lr, Float)
  2132. .REQUIRED_ATTR(l1, Float)
  2133. .REQUIRED_ATTR(l2, Float)
  2134. .REQUIRED_ATTR(l2_shrinkage, Float)
  2135. .REQUIRED_ATTR(lr_power, Float)
  2136. .ATTR(use_locking, Bool, false)
  2137. .OP_END_FACTORY_REG(SparseApplyFtrlV2D)
/**
* @brief Updates "var" in specified index according to the RMSProp algorithm.
* mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
* Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
* ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
* var <- var - mom\n
*
* @par Inputs:
* Nine inputs, including:
* @li var: A mutable tensor. Must be one of the data types defined in\n
* TensorType::NumberType(). Should be from a Variable().
* @li ms: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li mom: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li lr: A scalar. Must have the same type as "var".
* @li rho: A scalar. Must have the same type as "var".
* @li momentum: A scalar. Must have the same type as "var".
* @li epsilon: A scalar. Must have the same type as "var".
* @li grad: A tensor, specifying the gradient.
* @li indices: A vector of indices into the first dimension of "var", "mom" and "ms".
*
* @par Attributes:
* use_locking: An optional "bool". Defaults to "False". If "True", updating of
* the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the
* behavior is undefined, but may exhibit less contention.
*
* @par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
* @attention Constraints:
* @li Note that in this sparse implementation, "ms" and "mom" will not update
* in iterations during which "grad" is 0.
* @li The input tensors "var", "ms", and "mom" must have the same shape.
*
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyRMSProp.
*/
REG_OP(SparseApplyRMSProp)
    .INPUT(var, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyRMSProp)
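/* Example (illustrative only): a minimal C++ sketch of the per-row math listed in the
   comment above, applied only to the rows named in "indices". The function name and
   container types are hypothetical; this is not the kernel that GE dispatches.

#include <cmath>
#include <cstdint>
#include <vector>

void RmsPropSparseUpdate(std::vector<std::vector<float>> &var,
                         std::vector<std::vector<float>> &ms,
                         std::vector<std::vector<float>> &mom,
                         const std::vector<std::vector<float>> &grad,
                         const std::vector<int32_t> &indices,
                         float lr, float rho, float momentum, float epsilon) {
  for (size_t i = 0; i < indices.size(); ++i) {
    const int32_t row = indices[i];
    for (size_t j = 0; j < grad[i].size(); ++j) {
      const float g = grad[i][j];
      ms[row][j] = rho * ms[row][j] + (1.0f - rho) * g * g;        // ms <- rho*ms + (1-rho)*grad^2
      mom[row][j] = momentum * mom[row][j] +
                    lr * g / std::sqrt(ms[row][j] + epsilon);      // mom <- momentum*mom + lr*grad/sqrt(ms+eps)
      var[row][j] -= mom[row][j];                                  // var <- var - mom
    }
  }
}
*/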
/**
* @brief Updates "var" in specified index according to the RMSProp algorithm.
* A const input will be considered as an attribute.\n
* mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
* Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
* ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
* var <- var - mom
*
* @par Inputs:
* Six inputs, including:
* @li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
* @li ms: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li mom: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li lr: A scalar. Must have the same type as "var".
* @li grad: A tensor, specifying the gradient.
* @li indices: A vector of indices into the first dimension of "var", "ms" and "mom".
*
* @par Attributes:
* @li use_locking: An optional "bool". Defaults to "False". If "True",
* updating of the "var", "ms", and "mom" tensors will be protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
* @li rho: A required float scalar.
* @li momentum: A required float scalar.
* @li epsilon: A required float scalar.
*
* @par Outputs:
* @li var: A mutable tensor. Must have the same type as input "var".
* @li ms: A mutable tensor. Must have the same type as input "ms".
* @li mom: A mutable tensor. Must have the same type as input "mom".
*
* @attention Constraints:
* @li Note that in this sparse implementation, "ms" and "mom" will not update
* in iterations during which "grad" is 0.
* @li The input tensors "var", "ms" and "mom" must have the same shape.
*/
REG_OP(SparseApplyRMSPropD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(ms, TensorType::NumberType())
    .OUTPUT(mom, TensorType::NumberType())
    .REQUIRED_ATTR(rho, Float)
    .REQUIRED_ATTR(momentum, Float)
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyRMSPropD)
/**
* @brief Updates "var" in specified index according to the Adadelta algorithm.
* accum <- rho * accum + (1 - rho) * grad.square()\n
* update <- (accum_update + epsilon).sqrt() * (accum + epsilon).rsqrt() * grad\n
* var <- var - update * lr\n
* accum_update <- rho * accum_update + (1 - rho) * update.square()\n
*
* @par Inputs:
* Eight inputs, including:
* @li var: A mutable tensor. Must be one of the data types defined in\n
* TensorType::NumberType(). Should be from a Variable().
* @li accum: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li lr: A scalar. Must have the same type as "var".
* @li rho: A scalar. Must have the same type as "var".
* @li epsilon: A scalar. Must have the same type as "var".
* @li grad: A tensor, specifying the gradient.
* @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update".
*
* @par Attributes:
* use_locking: An optional "bool". Defaults to "False". If "True", updating of
* the "var", "accum", and "accum_update" tensors will be protected by a lock; otherwise the
* behavior is undefined, but may exhibit less contention.
*
* @par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
* @attention Constraints:
* @li Note that in this sparse implementation, "accum" and "accum_update" will not update
* in iterations during which "grad" is 0.
* @li The input tensors "var", "accum", and "accum_update" must have the same shape.
*
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyAdadelta.
*/
REG_OP(SparseApplyAdadelta)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(accum_update, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyAdadelta)
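/* Example (illustrative only): a minimal C++ sketch of the per-row Adadelta update
   listed in the comment above, applied only to the rows named in "indices". The
   function name and container types are hypothetical; this is not the kernel that
   GE dispatches.

#include <cmath>
#include <cstdint>
#include <vector>

void AdadeltaSparseUpdate(std::vector<std::vector<float>> &var,
                          std::vector<std::vector<float>> &accum,
                          std::vector<std::vector<float>> &accum_update,
                          const std::vector<std::vector<float>> &grad,
                          const std::vector<int32_t> &indices,
                          float lr, float rho, float epsilon) {
  for (size_t i = 0; i < indices.size(); ++i) {
    const int32_t row = indices[i];
    for (size_t j = 0; j < grad[i].size(); ++j) {
      const float g = grad[i][j];
      accum[row][j] = rho * accum[row][j] + (1.0f - rho) * g * g;
      const float update = std::sqrt(accum_update[row][j] + epsilon) /
                           std::sqrt(accum[row][j] + epsilon) * g;
      var[row][j] -= update * lr;
      accum_update[row][j] = rho * accum_update[row][j] + (1.0f - rho) * update * update;
    }
  }
}
*/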
/**
* @brief Updates "var" in specified index according to the Adadelta algorithm.
* A const input will be considered as an attribute.\n
* accum <- rho * accum + (1 - rho) * grad.square()\n
* update <- (accum_update + epsilon).sqrt() * (accum + epsilon).rsqrt() * grad\n
* var <- var - update * lr\n
* accum_update <- rho * accum_update + (1 - rho) * update.square()\n
*
* @par Inputs:
* Seven inputs, including:
* @li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
* @li accum: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li lr: A scalar. Must have the same type as "var".
* @li rho: A scalar. Must have the same type as "var".
* @li grad: A tensor, specifying the gradient.
* @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update".
*
* @par Attributes:
* @li use_locking: An optional "bool". Defaults to "False". If "True",
* updating of the "var", "accum", and "accum_update" tensors will be protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
* @li epsilon: A required float scalar.
*
* @par Outputs:
* @li var: A mutable tensor. Must have the same type as input "var".
* @li accum: A mutable tensor. Must have the same type as input "accum".
* @li accum_update: A mutable tensor. Must have the same type as input "accum_update".
*
* @attention Constraints:
* @li Note that in this sparse implementation, "accum" and "accum_update" will not update
* in iterations during which "grad" is 0.
* @li The input tensors "var", "accum" and "accum_update" must have the same shape.
*/
REG_OP(SparseApplyAdadeltaD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(accum_update, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(accum, TensorType::NumberType())
    .OUTPUT(accum_update, TensorType::NumberType())
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyAdadeltaD)
/**
* @brief Cleans the memory referenced by the workspace list.
* @par Attributes:
* @li automic_add_mem_size: A list of ints specifying the sizes of the workspaces to clean.
*/
REG_OP(AtomicAddrClean)
    .ATTR(automic_add_mem_size, ListInt, {})
    .OP_END_FACTORY_REG(AtomicAddrClean)
}  // namespace ge
#endif  // GE_OP_TRAINING_OPS_H
