You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn_training_ops.h 95 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_OP_TRAINING_OPS_H
  17. #define GE_OP_TRAINING_OPS_H
  18. #include "graph/operator_reg.h"
  19. namespace ge {
  20. /**
  21. *@brief Updates "var" according to the AdaMax algorithm.\n
  22. * t-1 mean previous period.
  23. * m_t <- beta1 * m{t-1} + (1 - beta1) * grad\n
  24. * v_t <- max(beta2 * v{t-1}, abs(grad))\n
  25. * var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
  26. *
  27. *@attention Constraints:\n
  28. * the input tensors must have the same shape.
  29. *
  30. *@par Inputs:
  31. *@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
  32. * Should be from a Variable().
  33. *@li m: A mutable tensor. Has the same type as "var".
  34. * Should be from a Variable().
  35. *@li v: A mutable tensor. Has the same type as "var".
  36. * Should be from a Variable().
  37. *@li beta1_power: A scalar. Has the same type as "var".
  38. *@li lr: learning_rate. A scalar. Has the same type as "var".
  39. *@li beta1: A scalar. Has the same type as "var".
  40. *@li beta2: A scalar. Has the same type as "var".
  41. *@li epsilon: A scalar. Has the same type as "var".
  42. *@li grad: A tensor for the gradient. Has the same type as "var".
  43. *
  44. *@par Attributes:\n
  45. * use_locking: An optional bool. Defaults to "False".
  46. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  47. * by a lock; otherwise the behavior is undefined, but may exhibit less
  48. * contention.
  49. *
  50. *@par Outputs:
  51. * var: A mutable tensor. Has the same type as input "var".
  52. *
  53. */
  54. REG_OP(ApplyAdaMax)
  55. .INPUT(var, TensorType::NumberType())
  56. .INPUT(m, TensorType::NumberType())
  57. .INPUT(v, TensorType::NumberType())
  58. .INPUT(beta1_power, TensorType::NumberType())
  59. .INPUT(lr, TensorType::NumberType())
  60. .INPUT(beta1, TensorType::NumberType())
  61. .INPUT(beta2, TensorType::NumberType())
  62. .INPUT(epsilon, TensorType::NumberType())
  63. .INPUT(grad, TensorType::NumberType())
  64. .OUTPUT(var, TensorType::NumberType())
  65. .ATTR(use_locking, Bool, false)
  66. .OP_END_FACTORY_REG(ApplyAdaMax)
  67. /**
  68. *@brief Updates "var" according to the AdaMax algorithm.\n
  69. * t-1 mean previous period.
  70. * m_t <- beta1 * m{t-1} + (1 - beta1) * grad\n
  71. * v_t <- max(beta2 * v{t-1}, abs(grad))\n
  72. * var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
  73. *
  74. *@attention Constraints:\n
  75. * the input tensors must have the same shape.
  76. *
  77. *@par Inputs:
  78. *@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
  79. * Should be from a Variable().
  80. *@li m: A mutable tensor. Has the same type as "var".
  81. * Should be from a Variable().
  82. *@li v: A mutable tensor. Has the same type as "var".
  83. * Should be from a Variable().
  84. *@li beta1_power: A scalar. Has the same type as "var".
  85. *@li lr: learning_rate. A scalar. Has the same type as "var".
  86. *@li beta1: A scalar. Has the same type as "var".
  87. *@li beta2: A scalar. Has the same type as "var".
  88. *@li epsilon: A scalar. Has the same type as "var".
  89. *@li grad: A tensor for the gradient. Has the same type as "var".
  90. *
  91. *@par Attributes:\n
  92. * use_locking: An optional bool. Defaults to "False".
  93. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  94. * by a lock; otherwise the behavior is undefined, but may exhibit less
  95. * contention.
  96. *
  97. *@par Outputs:
  98. * var: A mutable tensor. Has the same type as input "var".
  99. *
  100. *
  101. */
  102. REG_OP(ApplyAdaMaxD)
  103. .INPUT(var, TensorType::NumberType())
  104. .INPUT(m, TensorType::NumberType())
  105. .INPUT(v, TensorType::NumberType())
  106. .INPUT(beta1_power, TensorType::NumberType())
  107. .INPUT(lr, TensorType::NumberType())
  108. .INPUT(beta1, TensorType::NumberType())
  109. .INPUT(beta2, TensorType::NumberType())
  110. .INPUT(epsilon, TensorType::NumberType())
  111. .INPUT(grad, TensorType::NumberType())
  112. .OUTPUT(var, TensorType::NumberType())
  113. .OUTPUT(m, TensorType::NumberType())
  114. .OUTPUT(v, TensorType::NumberType())
  115. .ATTR(use_locking, Bool, false)
  116. .OP_END_FACTORY_REG(ApplyAdaMaxD)
  117. /**
  118. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  119. *@par Inputs:
  120. * Five inputs, including:
  121. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  122. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  123. *@li lr: An NCHW, NHWC, or ND Tensor of type float32.
  124. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  125. *@li indices: An NCHW, NHWC, or ND Tensor of type float32.
  126. *@par Attributes:
  127. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  128. *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False".
  129. *@par Outputs:
  130. *var: A Tensor. Has the same type and format as input "var".
  131. */
  132. REG_OP(SparseApplyAdagrad)
  133. .INPUT(var, TensorType({DT_FLOAT}))
  134. .INPUT(accum, TensorType({DT_FLOAT}))
  135. .INPUT(lr, TensorType({DT_FLOAT}))
  136. .INPUT(grad, TensorType({DT_FLOAT}))
  137. .INPUT(indices, TensorType({DT_INT32}))
  138. .OUTPUT(var, TensorType({DT_FLOAT}))
  139. .ATTR(use_locking, Bool, false)
  140. .ATTR(update_slots, Bool, true)
  141. .OP_END_FACTORY_REG(SparseApplyAdagrad)
  142. /**
  143. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  144. *@par Inputs:
  145. * Four inputs, including:
  146. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  147. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  148. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  149. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  150. *@par Attributes:
  151. *@li lr: Required, used for computation.
  152. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  153. *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False".
  154. *@par Outputs:
  155. *@li var: A Tensor. Has the same type and format as input "var".
  156. *@li accum: A Tensor. Has the same type and format as input "var".
  157. */
  158. REG_OP(SparseApplyAdagradD)
  159. .INPUT(var, TensorType({DT_FLOAT}))
  160. .INPUT(accum, TensorType({DT_FLOAT}))
  161. .INPUT(grad, TensorType({DT_FLOAT}))
  162. .INPUT(indices, TensorType({DT_INT32}))
  163. .OUTPUT(var, TensorType({DT_FLOAT}))
  164. .REQUIRED_ATTR(lr, Float)
  165. .ATTR(use_locking, Bool, false)
  166. .ATTR(update_slots, Bool, true)
  167. .OP_END_FACTORY_REG(SparseApplyAdagradD)
  168. /**
  169. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  170. *@par Inputs:
  171. *Six inputs, including:
  172. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  173. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  174. *@li lr: An NCHW, NHWC, or ND Tensor of type float32.
  175. *@li epsilon: An NCHW, NHWC, or ND Tensor of type float32.
  176. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  177. *@li indices: An NCHW, NHWC, or ND Tensor of type float32.
  178. *@par Attributes:
  179. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  180. *@li update_slots: An optional bool. Defaults to "True". If "False", the computation logic will be different.
  181. *@par Outputs:
  182. *var: A Tensor. Has the same type and format as input "var".
  183. */
  184. REG_OP(SparseApplyAdagradV2)
  185. .INPUT(var, TensorType({DT_FLOAT}))
  186. .INPUT(accum, TensorType({DT_FLOAT}))
  187. .INPUT(lr, TensorType({DT_FLOAT}))
  188. .INPUT(epsilon, TensorType({DT_FLOAT}))
  189. .INPUT(grad, TensorType({DT_FLOAT}))
  190. .INPUT(indices, TensorType({DT_INT32}))
  191. .OUTPUT(var, TensorType({DT_FLOAT}))
  192. .ATTR(use_locking, Bool, false)
  193. .ATTR(update_slots, Bool, true)
  194. .OP_END_FACTORY_REG(SparseApplyAdagradV2)
  195. /**
  196. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  197. *@par Inputs:
  198. *Four inputs, including:
  199. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  200. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  201. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  202. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  203. *@par Attributes:
  204. *@li lr: Required, used for computation.
  205. *@li epsilon: Required, used for computation.
  206. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  207. *@li update_slots: An optional bool. Defaults to "True". If "False", the computation logic will be different.
  208. *@par Outputs:
  209. *@li var: A Tensor. Has the same type and format as input "var".
  210. *@li accum: A Tensor. Has the same type and format as input "accum".
  211. */
  212. REG_OP(SparseApplyAdagradV2D)
  213. .INPUT(var, TensorType({DT_FLOAT}))
  214. .INPUT(accum, TensorType({DT_FLOAT}))
  215. .INPUT(grad, TensorType({DT_FLOAT}))
  216. .INPUT(indices, TensorType({DT_INT32}))
  217. .OUTPUT(var, TensorType({DT_FLOAT}))
  218. .OUTPUT(accum, TensorType({DT_FLOAT}))
  219. .REQUIRED_ATTR(lr, Float)
  220. .REQUIRED_ATTR(epsilon, Float)
  221. .ATTR(use_locking, Bool, false)
  222. .ATTR(update_slots, Bool, true)
  223. .OP_END_FACTORY_REG(SparseApplyAdagradV2D)
  224. /**
  225. *@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
  226. * want to use Nesterov momentum.\n
  227. * computing process: \n
  228. * accum = accum * momentum + grad\n
  229. * var -= lr * accum
  230. *
  231. *@attention Constraints:\n
  232. * the input tensors must have the same shape.
  233. *
  234. *@par Inputs:
  235. *@li var: A mutable tensor. Should be from a Variable().
  236. *@li accum: A mutable tensor. Has the same type as "var".
  237. * Should be from a Variable().
  238. *@li lr: A scalar. Has the same type as "var".
  239. *@li grad: A tensor for the gradient. Has the same type as "var".
  240. *
  241. *@par Attributes:
  242. *@li use_nesterov: An optional bool. Defaults to "False".
  243. * If "True", the tensor passed to compute grad will be
  244. * var - lr * momentum * accum, so in the end, the var you get is actually
  245. * var - lr * momentum * accum.
  246. *
  247. *@li use_locking: An optional bool. Defaults to "False".\n
  248. * If "True", updating of the "var", "ms", and "mom" tensors is protected by a lock;
  249. * otherwise the behavior is undefined, but may exhibit less contention.
  250. *
  251. *@par Outputs:
  252. * var: A mutable tensor. Has the same type as input "var".
  253. *
  254. */
  255. REG_OP(ApplyMomentum)
  256. .INPUT(var, TensorType::NumberType())
  257. .INPUT(accum, TensorType::NumberType())
  258. .INPUT(lr, TensorType::NumberType())
  259. .INPUT(grad, TensorType::NumberType())
  260. .INPUT(momentum, TensorType::NumberType())
  261. .OUTPUT(var, TensorType::NumberType())
  262. .ATTR(use_nesterov, Bool, false)
  263. .ATTR(use_locking, Bool, false)
  264. .OP_END_FACTORY_REG(ApplyMomentum)
  265. REG_OP(ApplyMomentumCCE)
  266. .INPUT(var, TensorType::NumberType())
  267. .INPUT(accum, TensorType::NumberType())
  268. .INPUT(lr, TensorType::NumberType())
  269. .INPUT(grad, TensorType::NumberType())
  270. .INPUT(momentum, TensorType::NumberType())
  271. .OUTPUT(var, TensorType::NumberType())
  272. .ATTR(use_nesterov, Bool, false)
  273. .ATTR(use_locking, Bool, false)
  274. .OP_END_FACTORY_REG(ApplyMomentumCCE)
  275. /**
  276. *@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
  277. * want to use Nesterov momentum.\n
  278. * computing process: \n
  279. * accum = accum * momentum + grad\n
  280. * var -= lr * accum
  281. *
  282. *@attention Constraints:\n
  283. * the input tensors must have the same shape.
  284. *
  285. *@par Inputs:
  286. *@li var: A mutable tensor. Should be from a Variable().
  287. *@li accum: A mutable tensor. Has the same type as "var".
  288. * Should be from a Variable().
  289. *@li lr: A scalar. Has the same type as "var".
  290. *@li grad: A tensor for the gradient. Has the same type as "var".
  291. *
  292. *@par Attributes:
  293. *@li use_nesterov: An optional bool. Defaults to "False".
  294. * If "True", the tensor passed to compute grad will be
  295. * var - lr * momentum * accum, so in the end, the var you get is actually
  296. * var - lr * momentum * accum.
  297. *
  298. *@li use_locking: An optional bool. Defaults to "False".\n
  299. * If "True", updating of the "var", "ms", and "mom" tensors is protected by a lock;
  300. * otherwise the behavior is undefined, but may exhibit less contention.
  301. *
  302. *@par Outputs:
  303. * var: A mutable tensor. Has the same type as input "var".
  304. * accum: A mutable tensor. Has the same type as input "accum".
  305. *
  306. */
  307. REG_OP(ApplyMomentumD)
  308. .INPUT(var, TensorType::NumberType())
  309. .INPUT(accum, TensorType::NumberType())
  310. .INPUT(lr, TensorType::NumberType())
  311. .INPUT(grad, TensorType::NumberType())
  312. .INPUT(momentum, TensorType::NumberType())
  313. .OUTPUT(var, TensorType::NumberType())
  314. .OUTPUT(accum, TensorType::NumberType())
  315. .ATTR(use_nesterov, Bool, false)
  316. .ATTR(use_locking, Bool, false)
  317. .OP_END_FACTORY_REG(ApplyMomentumD)
  318. /**
  319. *@brief Updates '*var' according to the momentum scheme.
  320. * accum = accum * momentum - grad * lr \n
  321. * if use_nesterov is True: \n
  322. * var += accum * momentum - grad * lr \n
  323. * else: \n
  324. * var += accum
  325. *
  326. *@par Inputs:
  327. *@li var: A mutable tensor. Must be one of the data types defined in
  328. * TensorType::NumberType(). Should be from a Variable().
  329. *@li accum: A mutable tensor. Has the same type as "var". Should be from a
  330. * Variable().
  331. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  332. * from a Variable().
  333. *@li grad: A tensor for the gradient. Has the same type as "var". Should be
  334. * from a Variable().
  335. *@li momentum: A scalar. Has the same type as "var".
  336. *
  337. *@par Attributes:
  338. *@li use_nesterov: An optional bool. Defaults to "False".
  339. * If "True", var will be updated by using Nesterov momentum.
  340. *@li use_locking: An optional bool. Defaults to "False".
  341. * If "True", updating of the "var" tensor is protected by a lock;
  342. * otherwise the behavior is undefined, but may exhibit less contention.
  343. *
  344. *@par Outputs:
  345. * var: A mutable tensor. Has the same type as input "var".
  346. *
  347. *@attention Constraints:
  348. * The input tensors must have the same shape.
  349. *
  350. *
  351. */
  352. REG_OP(ApplyKerasMomentum)
  353. .INPUT(var, TensorType::NumberType())
  354. .INPUT(accum, TensorType::NumberType())
  355. .INPUT(lr, TensorType::NumberType())
  356. .INPUT(grad, TensorType::NumberType())
  357. .INPUT(momentum, TensorType::NumberType())
  358. .OUTPUT(var, TensorType::NumberType())
  359. .ATTR(use_locking, Bool, false)
  360. .ATTR(use_nesterov, Bool, false)
  361. .OP_END_FACTORY_REG(ApplyKerasMomentum)
  362. /**
  363. *@brief Updates '*var' according to the momentum scheme.
  364. * accum = accum * momentum - grad * lr \n
  365. * if use_nesterov is True: \n
  366. * var += accum * momentum - grad * lr \n
  367. * else: \n
  368. * var += accum
  369. *
  370. *@par Inputs:
  371. *@li var: A mutable tensor. Must be one of the data types defined in
  372. * TensorType::NumberType(). Should be from a Variable().
  373. *@li accum: A mutable tensor. Has the same type as "var". Should be from a
  374. * Variable().
  375. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  376. * from a Variable().
  377. *@li grad: A tensor for the gradient. Has the same type as "var". Should be
  378. * from a Variable().
  379. *@li momentum: A scalar. Has the same type as "var". Should be from a
  380. * Variable().
  381. *
  382. *@par Attributes:
  383. *@li use_nesterov: An optional bool. Defaults to "False".
  384. * If "True", var will be updated by using nesterov momentum
  385. *@li use_locking: An optional bool. Defaults to "False".
  386. * If "True", updating of the "var" tensor is protected by a lock;
  387. * otherwise the behavior is undefined, but may exhibit less contention.
  388. *
  389. *@par Outputs:
  390. *@li var: A mutable tensor. Has the same type as input "var".
  391. *@li accum: A mutable tensor. Has the same type as input "var"
  392. *
  393. *@attention Constraints:
  394. * The input tensors must have the same shape.
  395. *
  396. *
  397. */
  398. REG_OP(ApplyKerasMomentumD)
  399. .INPUT(var, TensorType::NumberType())
  400. .INPUT(accum, TensorType::NumberType())
  401. .INPUT(lr, TensorType::NumberType())
  402. .INPUT(grad, TensorType::NumberType())
  403. .INPUT(momentum, TensorType::NumberType())
  404. .OUTPUT(var, TensorType::NumberType())
  405. .OUTPUT(accum, TensorType::NumberType())
  406. .ATTR(use_locking, Bool, false)
  407. .ATTR(use_nesterov, Bool, false)
  408. .OP_END_FACTORY_REG(ApplyKerasMomentumD)
  409. /**
  410. *@brief Updates '*var' according to the Adam algorithm..
  411. * lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)
  412. * m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
  413. * v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
  414. * vhat_t := max{vhat_{t-1}, v_t}
  415. * variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon)
  416. *
  417. *@par Inputs:
  418. *@li var: A mutable tensor. Must be one of the data types defined in
  419. * TensorType::NumberType(). Should be from a Variable().
  420. *@li m: A mutable tensor. Has the same type as "var". Should be from a
  421. * Variable().
  422. *@li v: A mutable tensor. Has the same type as "var". Should be from a
  423. * Variable().
  424. *@li vhat: A mutable tensor. Has the same type as "var". Should be from a
  425. * Variable().
  426. *@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
  427. * Variable().
  428. *@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
  429. * Variable().
  430. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  431. * from a Variable().
  432. *@li grad: A tensor for the gradient. Has the same type as "var". Should be
  433. * from a Variable().
  434. *
  435. *@par Attributes:
  436. *@li beta1: A scalar. Has the same type as "var".
  437. *@li beta2: A scalar. Has the same type as "var".
  438. *@li epsilon: A scalar. Has the same type as "var".
  439. *@li use_locking: An optional bool. Defaults to "False".
  440. * If "True", updating of the "var" tensor is protected by a lock;
  441. * otherwise the behavior is undefined, but may exhibit less contention.
  442. *
  443. *@par Outputs:
  444. *@li var: A mutable tensor. Has the same type as input "var".
  445. *@li m: A mutable tensor. Has the same type as input "var"
  446. *@li v: A mutable tensor. Has the same type as input "var"
  447. *@li vhat: A mutable tensor. Has the same type as input "var"
  448. *
  449. *@attention Constraints:
  450. * The input tensors must have the same shape.
  451. *
  452. *
  453. */
  454. REG_OP(ApplyAdamWithAmsgradD)
  455. .INPUT(var, TensorType::NumberType())
  456. .INPUT(m, TensorType::NumberType())
  457. .INPUT(v, TensorType::NumberType())
  458. .INPUT(vhat, TensorType::NumberType())
  459. .INPUT(beta1_power, TensorType::NumberType())
  460. .INPUT(beta2_power, TensorType::NumberType())
  461. .INPUT(lr, TensorType::NumberType())
  462. .INPUT(grad, TensorType::NumberType())
  463. .OUTPUT(var, TensorType::NumberType())
  464. .OUTPUT(m, TensorType::NumberType())
  465. .OUTPUT(v, TensorType::NumberType())
  466. .OUTPUT(vhat, TensorType::NumberType())
  467. .REQUIRED_ATTR(beta1, Float)
  468. .REQUIRED_ATTR(beta2, Float)
  469. .REQUIRED_ATTR(epsilon, Float)
  470. .ATTR(use_locking, Bool, false)
  471. .OP_END_FACTORY_REG(ApplyAdamWithAmsgradD)
  472. /**
  473. *@brief Updates '*var' according to the Adam algorithm..
  474. * lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)
  475. * m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
  476. * v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
  477. * vhat_t := max{vhat_{t-1}, v_t}
  478. * variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon)
  479. *
  480. *@par Inputs:
  481. *@li var: A mutable tensor. Must be one of the data types defined in
  482. * TensorType::NumberType(). Should be from a Variable().
  483. *@li m: A mutable tensor. Has the same type as "var". Should be from a
  484. * Variable().
  485. *@li v: A mutable tensor. Has the same type as "var". Should be from a
  486. * Variable().
  487. *@li vhat: A mutable tensor. Has the same type as "var". Should be from a
  488. * Variable().
  489. *@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
  490. * Variable().
  491. *@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
  492. * Variable().
  493. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  494. * from a Variable().
  495. *@li grad: A tensor for the gradient. Has the same type as "var". Should be
  496. * from a Variable().
  497. *
  498. *@par Attributes:
  499. *@li beta1: A scalar. Has the same type as "var".
  500. *@li beta2: A scalar. Has the same type as "var".
  501. *@li epsilon: A scalar. Has the same type as "var".
  502. *@li use_locking: An optional bool. Defaults to "False".
  503. * If "True", updating of the "var" tensor is protected by a lock;
  504. * otherwise the behavior is undefined, but may exhibit less contention.
  505. *
  506. *@par Outputs:
  507. *@li var: A mutable tensor. Has the same type as input "var".
  508. *@li m: A mutable tensor. Has the same type as input "var"
  509. *@li v: A mutable tensor. Has the same type as input "var"
  510. *@li vhat: A mutable tensor. Has the same type as input "var"
  511. *
  512. *@attention Constraints:
  513. * The input tensors must have the same shape.
  514. *
  515. *
  516. */
  517. REG_OP(ApplyAdamWithAmsgrad)
  518. .INPUT(var, TensorType::NumberType())
  519. .INPUT(m, TensorType::NumberType())
  520. .INPUT(v, TensorType::NumberType())
  521. .INPUT(vhat, TensorType::NumberType())
  522. .INPUT(beta1_power, TensorType::NumberType())
  523. .INPUT(beta2_power, TensorType::NumberType())
  524. .INPUT(lr, TensorType::NumberType())
  525. .INPUT(beta1, TensorType::NumberType())
  526. .INPUT(beta2, TensorType::NumberType())
  527. .INPUT(epsilon, TensorType::NumberType())
  528. .INPUT(grad, TensorType::NumberType())
  529. .OUTPUT(var, TensorType::NumberType())
  530. .ATTR(use_locking, Bool, false)
  531. .OP_END_FACTORY_REG(ApplyAdamWithAmsgrad)
  532. /**
  533. *@brief Updates "var" according to the AddSign update.\n
  534. * t-1 mean previous period.
  535. * m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
  536. * update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
  537. * var <- var - lr * update
  538. *
  539. *@attention Constraints:\n
  540. * the input tensors must have the same shape.
  541. *
  542. *@par Inputs:
  543. *@li var: A mutable tensor. Should be from a Variable().
  544. *@li m: A mutable tensor. Has the same type as "var".
  545. * Should be from a Variable().
  546. *@li lr: A scalar. Has the same type as "var".
  547. *@li logbase: A scalar. Has the same type as "var".
  548. *@li sign_decay: A scalar. Has the same type as "var".
  549. *@li beta: A scalar. Has the same type as "var".
  550. *@li grad: A tensor for the gradient. Has the same type as "var".
  551. *
  552. *@par Attributes:
  553. * use_locking: An optional bool. Defaults to "False".
  554. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  555. * by a lock; otherwise the behavior is undefined, but may exhibit less
  556. * contention.
  557. *
  558. *@par Outputs:
  559. * var: A mutable tensor. Has the same type as input "var".
  560. *
  561. */
  562. REG_OP(ApplyPowerSign)
  563. .INPUT(var, TensorType::NumberType())
  564. .INPUT(m, TensorType::NumberType())
  565. .INPUT(lr, TensorType::NumberType())
  566. .INPUT(logbase, TensorType::NumberType())
  567. .INPUT(sign_decay, TensorType::NumberType())
  568. .INPUT(beta, TensorType::NumberType())
  569. .INPUT(grad, TensorType::NumberType())
  570. .OUTPUT(var, TensorType::NumberType())
  571. .ATTR(use_locking, Bool, false)
  572. .OP_END_FACTORY_REG(ApplyPowerSign)
  573. /**
  574. *@brief Updates "var" according to the AddSign update.\n
  575. * t-1 mean previous period.
  576. * m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
  577. * update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
  578. * var <- var - lr * update
  579. *
  580. *@attention Constraints:\n
  581. * the input tensors must have the same shape.
  582. *
  583. *@par Inputs:
  584. *@li var: A mutable tensor. Should be from a Variable().
  585. *@li m: A mutable tensor. Has the same type as "var".
  586. * Should be from a Variable().
  587. *@li lr: A scalar. Has the same type as "var".
  588. *@li logbase: A scalar. Has the same type as "var".
  589. *@li sign_decay: A scalar. Has the same type as "var".
  590. *@li beta: A scalar. Has the same type as "var".
  591. *@li grad: A tensor for the gradient. Has the same type as "var".
  592. *
  593. *@par Attributes:
  594. * use_locking: An optional bool. Defaults to "False".
  595. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  596. * by a lock; otherwise the behavior is undefined, but may exhibit less
  597. * contention.
  598. *
  599. *@par Outputs:
  600. *@li var: A mutable tensor. Has the same type as input "var".
  601. *@li m: A mutable tensor. Has the same type as input "var".
  602. *
  603. *
  604. */
  605. REG_OP(ApplyPowerSignD)
  606. .INPUT(var, TensorType::NumberType())
  607. .INPUT(m, TensorType::NumberType())
  608. .INPUT(lr, TensorType::NumberType())
  609. .INPUT(logbase, TensorType::NumberType())
  610. .INPUT(sign_decay, TensorType::NumberType())
  611. .INPUT(beta, TensorType::NumberType())
  612. .INPUT(grad, TensorType::NumberType())
  613. .OUTPUT(var, TensorType::NumberType())
  614. .OUTPUT(m, TensorType::NumberType())
  615. .ATTR(use_locking, Bool, false)
  616. .OP_END_FACTORY_REG(ApplyPowerSignD)
  617. /**
  618. *@brief Updates "var" as FOBOS algorithm with fixed learning rate.\n
  619. * prox_v = var - alpha * delta\n
  620. * var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
  621. *
  622. *@attention Constraints:\n
  623. * the input tensors must have the same shape.
  624. *
  625. *@par Inputs:
  626. *@li var: A mutable tensor. Should be from a Variable().
  627. *@li alpha: A scalar. Has the same type as "var".
  628. *@li l1: A scalar. Has the same type as "var".
  629. *@li l2: A scalar. Has the same type as "var".
  630. *@li delta: A tensor. Has the same type as "var". The change.
  631. *
  632. *@par Attributes:
  633. * use_locking: An optional bool. Defaults to "False".
  634. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  635. * by a lock; otherwise the behavior is undefined, but may exhibit less
  636. * contention.
  637. *
  638. *@par Outputs:
  639. * var: A mutable tensor. Has the same type as input "var".
  640. *
  641. */
  642. REG_OP(ApplyProximalGradientDescent)
  643. .INPUT(var, TensorType::NumberType())
  644. .INPUT(alpha, TensorType::NumberType())
  645. .INPUT(l1, TensorType::NumberType())
  646. .INPUT(l2, TensorType::NumberType())
  647. .INPUT(delta, TensorType::NumberType())
  648. .OUTPUT(var, TensorType::NumberType())
  649. .ATTR(use_locking, Bool, false)
  650. .OP_END_FACTORY_REG(ApplyProximalGradientDescent)
  651. /**
  652. *@brief Updates "var" according to the AddSign update.
  653. *@par Inputs:
  654. *Seven inputs, including:
  655. * @li var: A mutable Tensor of type TensorType::NumberType().
  656. * Should be a Variable Tensor.
  657. * @li m: A mutable Tensor of the same type as "var".
  658. * Should be a Variable Tensor.
  659. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  660. * @li alpha: A Tensor of the same type as "var". Must be a scalar.
  661. * @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
  662. * @li beta: A Tensor of the same type as "var". Must be a scalar.
  663. * @li grad: A Tensor of the same type as "var", for the gradient.
  664. *@par Attributes:
  665. *use_locking: An optional bool. Defaults to "False".
  666. * If "True", updating of the "var" and "m" tensors will be
  667. * protected by a lock; otherwise the behavior is undefined,
  668. * but may exhibit less contention.
  669. *@par Outputs:
  670. *var: A mutable Tensor. Has the same type as "var".
  671. */
  672. REG_OP(ApplyAddSign)
  673. .INPUT(var, TensorType::NumberType())
  674. .INPUT(m, TensorType::NumberType())
  675. .INPUT(lr, TensorType::NumberType())
  676. .INPUT(alpha, TensorType::NumberType())
  677. .INPUT(sign_decay, TensorType::NumberType())
  678. .INPUT(beta, TensorType::NumberType())
  679. .INPUT(grad, TensorType::NumberType())
  680. .OUTPUT(var, TensorType::NumberType())
  681. .ATTR(use_locking, Bool, false)
  682. .OP_END_FACTORY_REG(ApplyAddSign)
  683. /**
  684. *@brief Updates "var" according to the AddSign update.
  685. *@par Inputs:
  686. *Seven inputs, including:
  687. * @li var: A mutable Tensor of type TensorType::NumberType().
  688. * Should be a Variable Tensor.
  689. * @li m: A mutable Tensor of the same type as "var".
  690. * Should be a Variable Tensor.
  691. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  692. * @li alpha: A Tensor of the same type as "var". Must be a scalar.
  693. * @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
  694. * @li beta: A Tensor of the same type as "var". Must be a scalar.
  695. * @li grad: A Tensor of the same type as "var", for the gradient.
  696. *@par Attributes:
  697. *use_locking: An optional bool. Defaults to "False".
  698. * If "True", updating of the "var" and "m" tensors will be
  699. * protected by a lock; otherwise the behavior is undefined,
  700. * but may exhibit less contention.
  701. *@par Outputs:
  702. *@li var: A mutable Tensor. Has the same type as "var".
  703. *@li m: A mutable Tensor. Has the same type as "m".
  704. */
  705. REG_OP(ApplyAddSignD)
  706. .INPUT(var, TensorType::NumberType())
  707. .INPUT(m, TensorType::NumberType())
  708. .INPUT(lr, TensorType::NumberType())
  709. .INPUT(alpha, TensorType::NumberType())
  710. .INPUT(sign_decay, TensorType::NumberType())
  711. .INPUT(beta, TensorType::NumberType())
  712. .INPUT(grad, TensorType::NumberType())
  713. .OUTPUT(var, TensorType::NumberType())
  714. .OUTPUT(m, TensorType::NumberType())
  715. .ATTR(use_locking, Bool, false)
  716. .OP_END_FACTORY_REG(ApplyAddSignD)
  717. /**
  718. *@brief Updates "var" according to the centered RMSProp algorithm.\n
  719. * The centered RMSProp algorithm uses an estimate of the centered second moment
  720. * (i.e., the variance) for normalization, as opposed to regular RMSProp, which
  721. * uses the (uncentered) second moment. This often helps with training, but is
  722. * slightly more expensive in terms of computation and memory.
  723. *
  724. * t-1 mean previous period.
  725. * mg <- rho * mg{t-1} + (1-rho) * grad\n
  726. * ms <- rho * ms{t-1} + (1-rho) * grad * grad\n
  727. * mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
  728. * var <- var - mom\n
  729. *
  730. *@attention Constraints:\n
  731. *@li in dense implementation of this algorithm, mg, ms, and mom will
  732. * update even if the grad is zero, but in this sparse implementation, mg, ms,
  733. * and mom will not update in iterations during which the grad is zero.
  734. *@li the input tensors must have the same shape.
  735. *
  736. *@par Inputs:
  737. *@li var: A mutable tensor. Should be from a Variable().
  738. *@li mg: A mutable tensor. Has the same type as "var".
  739. * Should be from a Variable().
  740. *@li ms: A mutable tensor. Has the same type as "var".
  741. * Should be from a Variable().
  742. *@li mom: A mutable tensor. Has the same type as "var".
  743. * Should be from a Variable().
  744. *@li lr: A scalar. Has the same type as "var".
  745. *@li rho: A scalar. Has the same type as "var".
  746. *@li momentum: A tensor. Has the same type as "var".
  747. *@li epsilon: A scalar. Has the same type as "var".
  748. *@li grad: A tensor for the gradient. Has the same type as "var".
  749. *
  750. *@par Attributes:
  751. * use_locking: An optional bool. Defaults to "False".
  752. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  753. * by a lock; otherwise the behavior is undefined, but may exhibit less
  754. * contention.
  755. *
  756. *@par Outputs:
  757. * var: A mutable tensor. Has the same type as input "var".
  758. *
  759. */
  760. REG_OP(ApplyCenteredRMSProp)
  761. .INPUT(var, TensorType::NumberType())
  762. .INPUT(mg, TensorType::NumberType())
  763. .INPUT(ms, TensorType::NumberType())
  764. .INPUT(mom, TensorType::NumberType())
  765. .INPUT(lr, TensorType::NumberType())
  766. .INPUT(rho, TensorType::NumberType())
  767. .INPUT(momentum, TensorType::NumberType())
  768. .INPUT(epsilon, TensorType::NumberType())
  769. .INPUT(grad, TensorType::NumberType())
  770. .OUTPUT(var, TensorType::NumberType())
  771. .ATTR(use_locking, Bool, false)
  772. .OP_END_FACTORY_REG(ApplyCenteredRMSProp)
  773. /**
  774. *@brief Updates "var" according to the centered RMSProp algorithm.\n
  775. * The centered RMSProp algorithm uses an estimate of the centered second moment
  776. * (i.e., the variance) for normalization, as opposed to regular RMSProp, which
  777. * uses the (uncentered) second moment. This often helps with training, but is
  778. * slightly more expensive in terms of computation and memory.
  779. *
  780. * t-1 mean previous period.
  781. * mg <- rho * mg{t-1} + (1-rho) * grad\n
  782. * ms <- rho * ms{t-1} + (1-rho) * grad * grad\n
  783. * mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
  784. * var <- var - mom\n
  785. *
  786. *@attention Constraints:\n
  787. *@li in dense implementation of this algorithm, mg, ms, and mom will
  788. * update even if the grad is zero, but in this sparse implementation, mg, ms,
  789. * and mom will not update in iterations during which the grad is zero.
  790. *@li the input tensors must have the same shape.
  791. *
  792. *@par Inputs:
  793. *@li var: A mutable tensor. Should be from a Variable().
  794. *@li mg: A mutable tensor. Has the same type as "var".
  795. * Should be from a Variable().
  796. *@li ms: A mutable tensor. Has the same type as "var".
  797. * Should be from a Variable().
  798. *@li mom: A mutable tensor. Has the same type as "var".
  799. * Should be from a Variable().
  800. *@li lr: A scalar. Has the same type as "var".
  801. *@li rho: A scalar. Has the same type as "var".
  802. *@li momentum: A tensor. Has the same type as "var".
  803. *@li epsilon: A scalar. Has the same type as "var".
  804. *@li grad: A tensor for the gradient. Has the same type as "var".
  805. *
  806. *@par Attributes:
  807. * use_locking: An optional bool. Defaults to "False".
  808. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  809. * by a lock; otherwise the behavior is undefined, but may exhibit less
  810. * contention.
  811. *
  812. *@par Outputs:
  813. *@li var: A mutable Tensor. Has the same type as "var".
  814. *@li mg: A mutable Tensor. Has the same type as "mg".
  815. *@li ms: A mutable Tensor. Has the same type as "ms".
  816. *@li mom: A mutable Tensor. Has the same type as "mom".
  817. *
  818. */
  819. REG_OP(ApplyCenteredRMSPropD)
  820. .INPUT(var, TensorType::NumberType())
  821. .INPUT(mg, TensorType::NumberType())
  822. .INPUT(ms, TensorType::NumberType())
  823. .INPUT(mom, TensorType::NumberType())
  824. .INPUT(lr, TensorType::NumberType())
  825. .INPUT(rho, TensorType::NumberType())
  826. .INPUT(momentum, TensorType::NumberType())
  827. .INPUT(epsilon, TensorType::NumberType())
  828. .INPUT(grad, TensorType::NumberType())
  829. .OUTPUT(var, TensorType::NumberType())
  830. .OUTPUT(mg, TensorType::NumberType())
  831. .OUTPUT(ms, TensorType::NumberType())
  832. .OUTPUT(mom, TensorType::NumberType())
  833. .ATTR(use_locking, Bool, false)
  834. .OP_END_FACTORY_REG(ApplyCenteredRMSPropD)
  835. /**
  836. *@brief Updates "var" by subtracting 'alpha' * 'delta' from it.\n
  837. * var -= delta * alpha
  838. *
  839. *@attention Constraints:\n
  840. * the input tensors must have the same shape.
  841. *
  842. *@par Inputs:
  843. *@li var: A mutable tensor. Should be from a Variable().
  844. *@li alpha: A scalar. Has the same type as "var".
  845. *@li delta: A tensor for the change. Has the same type as "var".
  846. *
  847. *@par Attributes:
  848. * use_locking: An optional bool. Defaults to "False".
  849. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  850. * by a lock; otherwise the behavior is undefined, but may exhibit less
  851. * contention.
  852. *
  853. *@par Outputs:
  854. * var: A mutable tensor. Has the same type as input "var".
  855. *
  856. */
  857. REG_OP(ApplyGradientDescent)
  858. .INPUT(var, TensorType::NumberType())
  859. .INPUT(alpha, TensorType::NumberType())
  860. .INPUT(delta, TensorType::NumberType())
  861. .OUTPUT(var, TensorType::NumberType())
  862. .ATTR(use_locking, Bool, false)
  863. .OP_END_FACTORY_REG(ApplyGradientDescent)
  864. /**
  865. *@brief Updates "var" according to the adagrad scheme.\n
  866. * accum += grad * grad\n
  867. * var -= lr * grad * (1 / sqrt(accum))
  868. *
  869. *@attention Constraints:\n
  870. * the input tensors must have the same shape.
  871. *
  872. *@par Inputs:
  873. *@li var: A mutable tensor. Should be from a Variable().
  874. *@li accum: A mutable tensor. Has the same type as "var".
  875. * Should be from a Variable().
  876. *@li lr: A scalar. Has the same type as "var".
  877. *@li grad: A tensor for the gradient. Has the same type as "var".
  878. *
  879. *@par Attributes:
  880. * use_locking: An optional bool. Defaults to "False".
  881. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  882. * by a lock; otherwise the behavior is undefined, but may exhibit less
  883. * contention.
  884. *
  885. *@par Outputs:
  886. * var: A mutable tensor. Has the same type as input "var".
  887. *
  888. */
  889. REG_OP(ApplyAdagrad)
  890. .INPUT(var, TensorType::NumberType())
  891. .INPUT(accum, TensorType::NumberType())
  892. .INPUT(lr, TensorType::NumberType())
  893. .INPUT(grad, TensorType::NumberType())
  894. .OUTPUT(var, TensorType::NumberType())
  895. .ATTR(update_slots, Bool, true)
  896. .ATTR(use_locking, Bool, false)
  897. .OP_END_FACTORY_REG(ApplyAdagrad)
  898. /**
  899. *@brief Updates "var" according to the adagrad scheme.\n
  900. * accum += grad * grad\n
  901. * var -= lr * grad * (1 / sqrt(accum))
  902. *
  903. *@attention Constraints:\n
  904. * the input tensors must have the same shape.
  905. *
  906. *@par Inputs:
  907. *@li var: A mutable tensor. Should be from a Variable().
  908. *@li accum: A mutable tensor. Has the same type as "var".
  909. * Should be from a Variable().
  910. *@li lr: A scalar. Has the same type as "var".
  911. *@li grad: A tensor for the gradient. Has the same type as "var".
  912. *
  913. *@par Attributes:
  914. * use_locking: An optional bool. Defaults to "False".
  915. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  916. * by a lock; otherwise the behavior is undefined, but may exhibit less
  917. * contention.
  918. *
  919. *@par Outputs:
  920. *@li var: A mutable tensor. Has the same type as input "var".
  921. *@li accum: A mutable tensor. Has the same type as input "var".
  922. *
  923. *
  924. */
  925. REG_OP(ApplyAdagradD)
  926. .INPUT(var, TensorType::NumberType())
  927. .INPUT(accum, TensorType::NumberType())
  928. .INPUT(lr, TensorType::NumberType())
  929. .INPUT(grad, TensorType::NumberType())
  930. .OUTPUT(var, TensorType::NumberType())
  931. .OUTPUT(accum, TensorType::NumberType())
  932. .ATTR(update_slots, Bool, true)
  933. .ATTR(use_locking, Bool, false)
  934. .OP_END_FACTORY_REG(ApplyAdagradD)
  935. /**
  936. * @brief Updates "var" according to the adagradv2 scheme.
  937. * accum += grad * grad \n
  938. * var -= lr * grad * (1 / sqrt(accum) + epsilon)
  939. *
  940. * @par Inputs:
  941. * @li var: A mutable tensor. Must be one of the data types defined in
  942. * TensorType::NumberType(). Should be from a Variable().
  943. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  944. * Variable().
  945. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  946. * from a Variable().
  947. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  948. * from a Variable().
  949. * @li epsilon: A scalar. Has the same type as "var".
  950. *
  951. * @par Attributes:
  952. * @li update_slots: An optional bool. Defaults to "True".
  953. * If "True", "accum" will be updated
  954. * @li use_locking: An optional bool. Defaults to "False".
  955. * If "True", updating of the "var" tensor is protected by a lock;
  956. * otherwise the behavior is undefined, but may exhibit less contention.
  957. *
  958. * @par Outputs:
  959. * var: A mutable tensor. Has the same type as input "var".
  960. *
  961. * @attention Constraints:
  962. * The input tensors must have the same shape.
  963. *
  964. */
  965. REG_OP(ApplyAdagradV2)
  966. .INPUT(var, TensorType::NumberType())
  967. .INPUT(accum, TensorType::NumberType())
  968. .INPUT(lr, TensorType::NumberType())
  969. .INPUT(epsilon, TensorType::NumberType())
  970. .INPUT(grad, TensorType::NumberType())
  971. .OUTPUT(var, TensorType::NumberType())
  972. .ATTR(update_slots, Bool, true)
  973. .ATTR(use_locking, Bool, false)
  974. .OP_END_FACTORY_REG(ApplyAdagradV2)
  975. /**
  976. * @brief Updates "var" according to the adagradv2 scheme.
  977. * accum += grad * grad \n
  978. * var -= lr * grad * (1 / sqrt(accum) + epsilon)
  979. *
  980. * @par Inputs:
  981. * @li var: A mutable tensor. Must be one of the data types defined in
  982. * TensorType::NumberType(). Should be from a Variable().
  983. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  984. * Variable().
  985. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  986. * from a Variable().
  987. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  988. * from a Variable().
  989. *
  990. * @par Attributes:
  991. * @li epsilon: A scalar. Has the same type as "var".
  992. * @li update_slots: An optional bool. Defaults to "True".
  993. * If "True", "accum" will be updated
  994. * @li use_locking: An optional bool. Defaults to "False".
  995. * If "True", updating of the "var" tensor is protected by a lock;
  996. * otherwise the behavior is undefined, but may exhibit less contention.
  997. *
  998. * @par Outputs:
  999. * var: A mutable tensor. Has the same type as input "var".
  1000. *
  1001. * @attention Constraints:
  1002. * The input tensors must have the same shape.
  1003. *
  1004. */
  1005. REG_OP(ApplyAdagradV2D)
  1006. .INPUT(var, TensorType::NumberType())
  1007. .INPUT(accum, TensorType::NumberType())
  1008. .INPUT(lr, TensorType::NumberType())
  1009. .INPUT(grad, TensorType::NumberType())
  1010. .OUTPUT(var, TensorType::NumberType())
  1011. .OUTPUT(accum, TensorType::NumberType())
  1012. .REQUIRED_ATTR(epsilon, Float)
  1013. .ATTR(update_slots, Bool, true)
  1014. .ATTR(use_locking, Bool, false)
  1015. .OP_END_FACTORY_REG(ApplyAdagradV2D)
  1016. /**
  1017. *@brief Updates "var" according to the proximal adagrad scheme.
  1018. *@par Inputs:
  1019. *Eight inputs, including:
  1020. * @li var: A mutable Tensor. Must be one of the following types:
  1021. * TensorType::NumberType(). Should be a Variable Tensor.
  1022. * @li gradient_accumulator: A mutable Tensor. Must have the same
  1023. * type as "var". Should be a Variable Tensor.
  1024. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  1025. * Should be a Variable Tensor.
  1026. * @li grad: A Tensor of the same type as "var", for the gradient.
  1027. * @li lr: A Tensor of the same type as "var".
  1028. * Scaling factor. Must be a scalar.
  1029. * @li l1: A Tensor of the same type as "var".
  1030. * L1 regulariation. Must be a scalar.
  1031. * @li l2: A Tensor of the same type as "var".
  1032. * L2 regulariation. Must be a scalar.
  1033. * @li global_step: A Tensor of type int32 or int64.
  1034. * Training step number. Must be a scalar.
  1035. *@par Attributes:
  1036. *use_locking: An optional bool. Defaults to "False".
  1037. * If "True", updating of the var and accum tensors will be
  1038. * protected by a lock; otherwise the behavior is undefined,
  1039. * but may exhibit less contention.
  1040. *@par Outputs:
  1041. *var: A mutable Tensor. Has the same type as "var".
  1042. */
  1043. REG_OP(ApplyAdagradDA)
  1044. .INPUT(var, TensorType::NumberType())
  1045. .INPUT(gradient_accumulator, TensorType::NumberType())
  1046. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  1047. .INPUT(grad, TensorType::NumberType())
  1048. .INPUT(lr, TensorType::NumberType())
  1049. .INPUT(l1, TensorType::NumberType())
  1050. .INPUT(l2, TensorType::NumberType())
  1051. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  1052. .OUTPUT(var, TensorType::NumberType())
  1053. .ATTR(use_locking, Bool, false)
  1054. .OP_END_FACTORY_REG(ApplyAdagradDA)
  1055. /**
  1056. *@brief Updates "var" according to the proximal adagrad scheme.
  1057. *@par Inputs:
  1058. *Eight inputs, including:
  1059. * @li var: A mutable Tensor. Must be one of the following types:
  1060. * TensorType::NumberType(). Should be a Variable Tensor.
  1061. * @li gradient_accumulator: A mutable Tensor. Must have the same
  1062. * type as "var". Should be a Variable Tensor.
  1063. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  1064. * Should be a Variable Tensor.
  1065. * @li grad: A Tensor of the same type as "var", for the gradient.
  1066. * @li lr: A Tensor of the same type as "var".
  1067. * Scaling factor. Must be a scalar.
  1068. * @li l1: A Tensor of the same type as "var".
  1069. * L1 regulariation. Must be a scalar.
  1070. * @li l2: A Tensor of the same type as "var".
  1071. * L2 regulariation. Must be a scalar.
  1072. * @li global_step: A Tensor of type int32 or int64.
  1073. * Training step number. Must be a scalar.
  1074. *@par Attributes:
  1075. *use_locking: An optional bool. Defaults to "False".
  1076. * If "True", updating of the var and accum tensors will be
  1077. * protected by a lock; otherwise the behavior is undefined,
  1078. * but may exhibit less contention.
  1079. *@par Outputs:
  1080. *var: A mutable Tensor. Has the same type as "var".
  1081. *gradient_accumulator: A mutable Tensor. Has the same type as "var".
  1082. *gradient_squared_accumulator: A mutable Tensor. Has the same type as "var".
  1083. */
  1084. REG_OP(ApplyAdagradDAD)
  1085. .INPUT(var, TensorType::NumberType())
  1086. .INPUT(gradient_accumulator, TensorType::NumberType())
  1087. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  1088. .INPUT(grad, TensorType::NumberType())
  1089. .INPUT(lr, TensorType::NumberType())
  1090. .INPUT(l1, TensorType::NumberType())
  1091. .INPUT(l2, TensorType::NumberType())
  1092. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  1093. .OUTPUT(var, TensorType::NumberType())
  1094. .OUTPUT(gradient_accumulator, TensorType::NumberType())
  1095. .OUTPUT(gradient_squared_accumulator, TensorType::NumberType())
  1096. .ATTR(use_locking, Bool, false)
  1097. .OP_END_FACTORY_REG(ApplyAdagradDAD)
  1098. /**
  1099. *@brief Returns the dimension index in the destination data format given the one in
  1100. * the source data format.
  1101. *
  1102. *@par Inputs:
  1103. * x: A tensor of type int32 or int64.
  1104. * A Tensor with each element as a dimension index in source data format.
  1105. * Must be in the range [-4, 4).
  1106. *
  1107. *@par Attributes:
  1108. *@li src_format: An optional string. Defaults to NHWC.
  1109. * source data format.
  1110. *@li dst_format: An optional string. Defaults to NCHW.
  1111. * destination data format.
  1112. *
  1113. *@par Outputs:
  1114. * y: A tensor. Has the same type as "x".
  1115. *
  1116. */
  1117. REG_OP(DataFormatDimMap)
  1118. .INPUT(x, TensorType::IndexNumberType())
  1119. .ATTR(src_format, String, "NHWC")
  1120. .ATTR(dst_format, String, "NCHW")
  1121. .OUTPUT(y, TensorType::IndexNumberType())
  1122. .OP_END_FACTORY_REG(DataFormatDimMap)
  1123. /**
  1124. * @brief Implements stochastic gradient descent (optionally with momentum).\n
  1125. * Nesterov momentum is based on the formula from
  1126. * On the importance of initialization and momentum in deep learning.\n
  1127. * @par Inputs:
  1128. * @li parameters: A mutable tensor of type float16 or float32.\n
  1129. * Specifies the iterable of parameters to optimize or dicts defining parameter
  1130. * groups.
  1131. * @li gradient: A tensor of type float16 or float32.\n
  1132. * Specifies the gradient of training step.
  1133. * @li learning_rate: A tensor of type float16 or float32.\n
  1134. * Specifies the learing_rate of training step.
  1135. * @li accum: A tensor of type float16 or float32.
  1136. * Specifies the velocity of training step.
  1137. * @li momentum: A tensor of type float16 or float32.
  1138. * Specifies the momentum factor.
  1139. * @li stat: A tensor of type float16 or float32.
  1140. * Specifies the status representing the first step or not.
  1141. * @par Attributes:
  1142. * @li dampening: An optional float, specifying the dampening for momentum.
  1143. * Defaults to "0.0".
  1144. * @li weight_decay: An optional float, specifying the L2 penalty. Defaults to
  1145. * "0.0".
  1146. * @li nesterov: An optional bool, specifying whether to enable Nesterov
  1147. * momentum. Defaults to "False".
  1148. * @par Outputs:
  1149. * parameters: A mutable tensor same as input "parameters".
  1150. * @see ApplyMomentum()
  1151. */
  1152. REG_OP(SGD)
  1153. .INPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  1154. .INPUT(gradient, TensorType(DT_FLOAT, DT_FLOAT16))
  1155. .INPUT(learning_rate, TensorType(DT_FLOAT, DT_FLOAT16))
  1156. .INPUT(accum, TensorType(DT_FLOAT, DT_FLOAT16))
  1157. .INPUT(momentum, TensorType(DT_FLOAT, DT_FLOAT16))
  1158. .INPUT(stat, TensorType(DT_FLOAT, DT_FLOAT16))
  1159. .OUTPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  1160. .ATTR(dampening, Float, 0.0)
  1161. .ATTR(weight_decay, Float, 0.0)
  1162. .ATTR(nesterov, Bool, false)
  1163. .OP_END_FACTORY_REG(SGD)
  1164. /**
  1165. * @brief Updates "var" according to the RMSProp algorithm.\n
  1166. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  1167. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  1168. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  1169. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  1170. * var <- var - mom\n
  1171. *
  1172. * @par Inputs:
  1173. * @li var: A mutable tensor. Must be one of the data types defined in\n
  1174. * TensorType::NumberType(). Should be from a Variable().
  1175. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  1176. * Variable().
  1177. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  1178. * Variable().
  1179. * @li lr: A scalar. Must have the same type as "var".
  1180. * @li rho: A scalar. Must have the same type as "var".
  1181. * @li momentum: A scalar. Must have the same type as "var".
  1182. * @li epsilon: A scalar. Must have the same type as "var".
  1183. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  1184. *
  1185. * @par Attributes:
  1186. * use_locking: An optional "bool". Defaults to "False". If "True", updating of\n
  1187. * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the\n
  1188. * behavior is undefined, but may exhibit less contention.
  1189. *
  1190. * @par Outputs:
  1191. * var: A mutable tensor. Has the same type as input "var".
  1192. *
  1193. * @attention Constraints:
  1194. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will \n
  1195. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom" \n
  1196. * will not update in iterations during which "grad" is 0.
  1197. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  1198. */
  1199. REG_OP(ApplyRMSProp)
  1200. .INPUT(var, TensorType::NumberType())
  1201. .INPUT(ms, TensorType::NumberType())
  1202. .INPUT(mom, TensorType::NumberType())
  1203. .INPUT(lr, TensorType::NumberType())
  1204. .INPUT(rho, TensorType::NumberType())
  1205. .INPUT(momentum, TensorType::NumberType())
  1206. .INPUT(epsilon, TensorType::NumberType())
  1207. .INPUT(grad, TensorType::NumberType())
  1208. .OUTPUT(var, TensorType::NumberType())
  1209. .ATTR(use_locking, Bool, false)
  1210. .OP_END_FACTORY_REG(ApplyRMSProp)
  1211. /**
  1212. * @brief Updates "var", "ms" and "mom" according to the RMSProp algorithm.
  1213. * Unlike ApplyRMSProp, "rho", "momentum" and "epsilon" are float attributes.\n
  1214. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  1215. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  1216. * var <- var - mom
  1217. *
  1218. * @par Inputs:
  1219. * @li var: A mutable tensor. Must be one of the data types defined in\n
  1220. * TensorType::NumberType(). Should be from a Variable().
  1221. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  1222. * Variable().
  1223. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  1224. * Variable().
  1225. * @li lr: A scalar. Must have the same type as "var".
  1226. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  1227. *
  1228. * @par Attributes:
  1229. * @li use_locking: An optional "bool". Defaults to "False". If "True", updating\n
  1230. * of the "var", "ms", and "mom" tensors will be protected by a lock;\n
  1231. * otherwise the behavior is undefined, but may exhibit less contention.
  1232. * @li rho: A required float. Decay rate of the mean square "ms".
  1233. * @li momentum: A required float. Momentum coefficient applied to "mom".
  1234. * @li epsilon: A required float. Added to "ms" before the square root.
  1235. *
  1236. * @par Outputs:
  1237. * @li var: A mutable tensor. Has the same type as input "var".
  1238. * @li ms: A mutable tensor. Has the same type as input "ms".
  1239. * @li mom: A mutable tensor. Has the same type as input "mom".
  1240. *
  1241. * @attention Constraints:
  1242. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will\n
  1243. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"\n
  1244. * will not update in iterations during which "grad" is 0.
  1245. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  1246. */
  1247. REG_OP(ApplyRMSPropD)
  1248. .INPUT(var, TensorType::NumberType())
  1249. .INPUT(ms, TensorType::NumberType())
  1250. .INPUT(mom, TensorType::NumberType())
  1251. .INPUT(lr, TensorType::NumberType())
  1252. .INPUT(grad, TensorType::NumberType())
  1253. .OUTPUT(var, TensorType::NumberType())
  1254. .OUTPUT(ms, TensorType::NumberType())
  1255. .OUTPUT(mom, TensorType::NumberType())
  1256. .REQUIRED_ATTR(rho, Float)
  1257. .REQUIRED_ATTR(momentum, Float)
  1258. .REQUIRED_ATTR(epsilon, Float)
  1259. .ATTR(use_locking, Bool, false)
  1260. .OP_END_FACTORY_REG(ApplyRMSPropD)
  1261. /**
  1262. *@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate.
  1263. *@par Inputs:
  1264. *Six inputs, including:
  1265. * @li var: A mutable Tensor of type TensorType::NumberType().
  1266. * Should be from a Variable().
  1267. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  1268. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1269. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1270. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1271. * @li grad: A Tensor of the same type as "var", for the gradient.
  1272. *@par Attributes:
  1273. *use_locking: An optional bool. Defaults to "False". If "True", updating of the
  1274. * "var" and "accum" tensors will be protected by a lock; otherwise the behavior
  1275. * is undefined, but may exhibit less contention.
  1276. *@par Outputs:
  1277. *var: A mutable Tensor. Has the same type as input "var".
  1278. */
  1279. REG_OP(ApplyProximalAdagrad)
  1280. .INPUT(var, TensorType::NumberType())
  1281. .INPUT(accum, TensorType::NumberType())
  1282. .INPUT(lr, TensorType::NumberType())
  1283. .INPUT(l1, TensorType::NumberType())
  1284. .INPUT(l2, TensorType::NumberType())
  1285. .INPUT(grad, TensorType::NumberType())
  1286. .OUTPUT(var, TensorType::NumberType())
  1287. .ATTR(use_locking, Bool, false)
  1288. .OP_END_FACTORY_REG(ApplyProximalAdagrad)
  1289. /**
  1290. *@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate.
  1291. *@par Inputs:
  1292. *Six inputs, including:
  1293. * @li var: A mutable Tensor of type TensorType::NumberType(). Should be from a Variable().
  1294. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  1295. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1296. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1297. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1298. * @li grad: A Tensor of the same type as "var", for the gradient.
  1299. *@par Attributes:
  1300. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and
  1301. * "accum" tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1302. *@par Outputs:
  1303. * @li var: A mutable Tensor. Has the same type as "var".
  1304. * @li accum: A mutable Tensor. Has the same type as "var".
  1305. */
  1306. REG_OP(ApplyProximalAdagradD)
  1307. .INPUT(var, TensorType::NumberType())
  1308. .INPUT(accum, TensorType::NumberType())
  1309. .INPUT(lr, TensorType::NumberType())
  1310. .INPUT(l1, TensorType::NumberType())
  1311. .INPUT(l2, TensorType::NumberType())
  1312. .INPUT(grad, TensorType::NumberType())
  1313. .OUTPUT(var, TensorType::NumberType())
  1314. .OUTPUT(accum, TensorType::NumberType())
  1315. .ATTR(use_locking, Bool, false)
  1316. .OP_END_FACTORY_REG(ApplyProximalAdagradD)
  1317. /**
  1318. *@brief Updates entries in "var" and "accum" according to the Proximal Adagrad algorithm.\n
  1319. * Compared with op ApplyProximalAdagrad, an additional index tensor is input.
  1320. * Only the entries selected by "indices" along the first dimension of "var" and "accum" are updated.
  1321. *@par Inputs:
  1322. * Seven inputs, including:\n
  1323. * @li var: A mutable Tensor. Must be one of the types defined in\n
  1324. * TensorType::NumberType(). Should be a Variable Tensor.
  1325. * @li accum: A mutable Tensor of the same type as "var".\n
  1326. * Should be a Variable Tensor.
  1327. * @li lr: A Tensor of the same type as "var".\n
  1328. * Scaling factor. Must be a scalar.
  1329. * @li l1: A Tensor of the same type as "var".\n
  1330. * L1 regularization. Must be a scalar.
  1331. * @li l2: A Tensor of the same type as "var".\n
  1332. * L2 regularization. Must be a scalar.
  1333. * @li grad: A Tensor. Has the same type as "var".\n
  1334. * The gradient.
  1335. * @li indices: A vector of indices into the first dimension of "var" and "accum".\n
  1336. * Must be one of the types defined in TensorType::IndexNumberType().
  1337. *@par Attributes:
  1338. *use_locking: An optional bool. Defaults to "False".\n
  1339. * If "True", updating of the "var" and "accum" tensors will be protected by a lock;\n
  1340. * otherwise the behavior is undefined, but may exhibit less contention.
  1341. *@par Outputs:
  1342. *var: A mutable Tensor. Has the same type as "var".
  1343. */
  1344. REG_OP(SparseApplyProximalAdagrad)
  1345. .INPUT(var, TensorType::NumberType())
  1346. .INPUT(accum, TensorType::NumberType())
  1347. .INPUT(lr, TensorType::NumberType())
  1348. .INPUT(l1, TensorType::NumberType())
  1349. .INPUT(l2, TensorType::NumberType())
  1350. .INPUT(grad, TensorType::NumberType())
  1351. .INPUT(indices, TensorType::IndexNumberType())
  1352. .OUTPUT(var, TensorType::NumberType())
  1353. .ATTR(use_locking, Bool, false)
  1354. .OP_END_FACTORY_REG(SparseApplyProximalAdagrad)
  1355. /**
  1356. *@brief Updates entries in "var" and "accum" according to the Proximal Adagrad algorithm.\n
  1357. * Compared with op ApplyProximalAdagrad, an additional index tensor is input.
  1358. * Only the entries selected by "indices" along the first dimension of "var" and "accum" are updated.
  1359. *@par Inputs:
  1360. * Seven inputs, including:\n
  1361. * @li var: A mutable Tensor. Must be one of the types defined in\n
  1362. * TensorType::NumberType(). Should be a Variable Tensor.
  1363. * @li accum: A mutable Tensor of the same type as "var".\n
  1364. * Should be a Variable Tensor.
  1365. * @li lr: A Tensor of the same type as "var".\n
  1366. * Scaling factor. Must be a scalar.
  1367. * @li l1: A Tensor of the same type as "var".\n
  1368. * L1 regularization. Must be a scalar.
  1369. * @li l2: A Tensor of the same type as "var".\n
  1370. * L2 regularization. Must be a scalar.
  1371. * @li grad: A Tensor. Has the same type as "var".\n
  1372. * The gradient.
  1373. * @li indices: A vector of indices into the first dimension of "var" and "accum".\n
  1374. * Must be one of the types defined in TensorType::IndexNumberType().
  1375. *@par Attributes:
  1376. *use_locking: An optional bool. Defaults to "False".\n
  1377. * If "True", updating of the "var" and "accum" tensors will be protected by a lock;\n
  1378. * otherwise the behavior is undefined, but may exhibit less contention.
  1379. *@par Outputs:
  1380. *@li var: A mutable Tensor. Has the same type as "var".
  1381. *@li accum: A mutable Tensor. Has the same type as "var".
  1382. */
  1383. REG_OP(SparseApplyProximalAdagradD)
  1384. .INPUT(var, TensorType::NumberType())
  1385. .INPUT(accum, TensorType::NumberType())
  1386. .INPUT(lr, TensorType::NumberType())
  1387. .INPUT(l1, TensorType::NumberType())
  1388. .INPUT(l2, TensorType::NumberType())
  1389. .INPUT(grad, TensorType::NumberType())
  1390. .INPUT(indices, TensorType::IndexNumberType())
  1391. .OUTPUT(var, TensorType::NumberType())
  1392. .OUTPUT(accum, TensorType::NumberType())
  1393. .ATTR(use_locking, Bool, false)
  1394. .OP_END_FACTORY_REG(SparseApplyProximalAdagradD)
  1395. /**
  1396. *@brief Updates "var" according to the Ftrl-proximal scheme.
  1397. *@par Inputs:
  1398. *Eight inputs, including:
  1399. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1400. * Should be a Variable Tensor.
  1401. * @li accum: A mutable Tensor of the same type as "var".
  1402. * Should be a Variable Tensor.
  1403. * @li linear: A mutable Tensor of the same type as "var".
  1404. * Should be a Variable Tensor.
  1405. * @li grad: A Tensor of the same type as "var", for the gradient.
  1406. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1407. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1408. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1409. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1410. *@par Attributes:
  1411. *use_locking: An optional bool. Defaults to "False".
  1412. * If "True", updating of the "var" and "accum" tensors will be
  1413. * protected by a lock; otherwise the behavior is undefined,
  1414. * but may exhibit less contention.
  1415. *@par Outputs:
  1416. *var: A mutable Tensor. Has the same type as "var".
  1417. */
  1418. REG_OP(ApplyFtrl)
  1419. .INPUT(var, TensorType::NumberType())
  1420. .INPUT(accum, TensorType::NumberType())
  1421. .INPUT(linear, TensorType::NumberType())
  1422. .INPUT(grad, TensorType::NumberType())
  1423. .INPUT(lr, TensorType::NumberType())
  1424. .INPUT(l1, TensorType::NumberType())
  1425. .INPUT(l2, TensorType::NumberType())
  1426. .INPUT(lr_power, TensorType::NumberType())
  1427. .OUTPUT(var, TensorType::NumberType())
  1428. .ATTR(use_locking, Bool, false)
  1429. .OP_END_FACTORY_REG(ApplyFtrl)
  1430. /**
  1431. *@brief Updates "var", "accum" and "linear" according to the Ftrl-proximal scheme.
  1432. *@par Inputs:
  1433. *Eight inputs, including:
  1434. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1435. * Should be a Variable Tensor.
  1436. * @li accum: A mutable Tensor of the same type as "var".
  1437. * Should be a Variable Tensor.
  1438. * @li linear: A mutable Tensor of the same type as "var".
  1439. * Should be a Variable Tensor.
  1440. * @li grad: A Tensor of the same type as "var", for the gradient.
  1441. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1442. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1443. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1444. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1445. *@par Attributes:
  1446. *use_locking: An optional bool. Defaults to "False".
  1447. * If "True", updating of the "var" and "accum" tensors will be
  1448. * protected by a lock; otherwise the behavior is undefined,
  1449. * but may exhibit less contention.
  1450. *@par Outputs:
  1451. *@li var: A mutable Tensor. Has the same type as "var".
  1452. *@li accum: A mutable Tensor. Has the same type as "accum".
  1453. *@li linear: A mutable Tensor. Has the same type as "linear".
  1454. */
  1455. REG_OP(ApplyFtrlD)
  1456. .INPUT(var, TensorType::NumberType())
  1457. .INPUT(accum, TensorType::NumberType())
  1458. .INPUT(linear, TensorType::NumberType())
  1459. .INPUT(grad, TensorType::NumberType())
  1460. .INPUT(lr, TensorType::NumberType())
  1461. .INPUT(l1, TensorType::NumberType())
  1462. .INPUT(l2, TensorType::NumberType())
  1463. .INPUT(lr_power, TensorType::NumberType())
  1464. .OUTPUT(var, TensorType::NumberType())
  1465. .OUTPUT(accum, TensorType::NumberType())
  1466. .OUTPUT(linear, TensorType::NumberType())
  1467. .ATTR(use_locking, Bool, false)
  1468. .OP_END_FACTORY_REG(ApplyFtrlD)
  1469. /**
  1470. *@brief Update "var" according to the Ftrl-proximal scheme.
  1471. *@par Inputs:
  1472. *Nine inputs, including:
  1473. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1474. * Should be a Variable Tensor.
  1475. * @li accum: A mutable Tensor of the same type as "var".
  1476. * Should be a Variable Tensor.
  1477. * @li linear: A mutable Tensor of the same type as "var".
  1478. * Should be a Variable Tensor.
  1479. * @li grad: A Tensor of the same type as "var", for the gradient.
  1480. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1481. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1482. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1483. * @li l2_shrinkage: A Tensor of the same type as "var", for L2 shrinkage regularization.
  1484. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1485. *@par Attributes:
  1486. *use_locking: An optional bool. Defaults to "False".
  1487. * If "True", updating of the "var" and "accum" tensors will be
  1488. * protected by a lock; otherwise the behavior is undefined,
  1489. * but may exhibit less contention.
  1490. *@par Outputs:
  1491. *var: A mutable Tensor. Has the same type as "var".
  1492. */
  1493. REG_OP(ApplyFtrlV2)
  1494. .INPUT(var, TensorType::NumberType())
  1495. .INPUT(accum, TensorType::NumberType())
  1496. .INPUT(linear, TensorType::NumberType())
  1497. .INPUT(grad, TensorType::NumberType())
  1498. .INPUT(lr, TensorType::NumberType())
  1499. .INPUT(l1, TensorType::NumberType())
  1500. .INPUT(l2, TensorType::NumberType())
  1501. .INPUT(l2_shrinkage, TensorType::NumberType())
  1502. .INPUT(lr_power, TensorType::NumberType())
  1503. .OUTPUT(var, TensorType::NumberType())
  1504. .ATTR(use_locking, Bool, false)
  1505. .OP_END_FACTORY_REG(ApplyFtrlV2)
  1506. /**
  1507. *@brief Update "var", "accum" and "linear" according to the Ftrl-proximal scheme.
  1508. *@par Inputs:
  1509. *Nine inputs, including:
  1510. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1511. * Should be a Variable Tensor.
  1512. * @li accum: A mutable Tensor of the same type as "var".
  1513. * Should be a Variable Tensor.
  1514. * @li linear: A mutable Tensor of the same type as "var".
  1515. * Should be a Variable Tensor.
  1516. * @li grad: A Tensor of the same type as "var", for the gradient.
  1517. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1518. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1519. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1520. * @li l2_shrinkage: A Tensor of the same type as "var", for L2 shrinkage regularization.
  1521. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1522. *@par Attributes:
  1523. *use_locking: An optional bool. Defaults to "False".
  1524. * If "True", updating of the "var" and "accum" tensors will be
  1525. * protected by a lock; otherwise the behavior is undefined,
  1526. * but may exhibit less contention.
  1527. *@par Outputs:
  1528. *@li var: A mutable Tensor. Has the same type as "var".
  1529. *@li accum: A mutable Tensor. Has the same type as "accum".
  1530. *@li linear: A mutable Tensor. Has the same type as "linear".
  1531. */
  1532. REG_OP(ApplyFtrlV2D)
  1533. .INPUT(var, TensorType::NumberType())
  1534. .INPUT(accum, TensorType::NumberType())
  1535. .INPUT(linear, TensorType::NumberType())
  1536. .INPUT(grad, TensorType::NumberType())
  1537. .INPUT(lr, TensorType::NumberType())
  1538. .INPUT(l1, TensorType::NumberType())
  1539. .INPUT(l2, TensorType::NumberType())
  1540. .INPUT(l2_shrinkage, TensorType::NumberType())
  1541. .INPUT(lr_power, TensorType::NumberType())
  1542. .OUTPUT(var, TensorType::NumberType())
  1543. .OUTPUT(accum, TensorType::NumberType())
  1544. .OUTPUT(linear, TensorType::NumberType())
  1545. .ATTR(use_locking, Bool, false)
  1546. .OP_END_FACTORY_REG(ApplyFtrlV2D)
  1547. /**
  1548. *@brief Updates "var" according to the Adam algorithm.\n
  1549. * lr_t <- lr * sqrt(1 - beta2_power) / (1 - beta1_power)\n
  1550. * m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
  1551. * v_t <- beta2 * v_{t-1} + (1 - beta2) * grad * grad\n
  1552. * var <- var - lr_t * m_t / (sqrt(v_t) + epsilon)
  1553. *
  1554. *@attention Constraints:\n
  1555. * The input tensors "var", "m", "v" and "grad" must have the same shape.
  1556. *
  1557. *@par Inputs:
  1558. *@li var: A mutable Tensor of the type TensorType::NumberType().
  1559. * Should be from a Variable().
  1560. *@li m: A mutable Tensor of the same type as "var".
  1561. * Should be from a Variable().
  1562. *@li v: A mutable Tensor of the same type as "var".
  1563. * Should be from a Variable().
  1564. *@li beta1_power: A scalar of the same type as "var", holding beta1^t.
  1565. *@li beta2_power: A scalar of the same type as "var", holding beta2^t.
  1566. *@li lr: learning_rate. A scalar of the same type as "var".
  1567. *@li beta1: A scalar of the same type as "var".
  1568. *@li beta2: A scalar of the same type as "var".
  1569. *@li epsilon: A scalar of the same type as "var".
  1570. *@li grad: A Tensor of the same type as "var", for the gradient.
  1571. *
  1572. *@par Attributes:\n
  1573. *@li use_locking: An optional bool. Defaults to "False".
  1574. * If "True", updating of the "var", "m", and "v" tensors will be protected
  1575. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1576. * contention.
  1577. *@li use_nesterov: An optional bool. Defaults to "False".
  1578. * If "True", uses the nesterov update.
  1579. *
  1580. *@par Outputs:
  1581. * var: A mutable Tensor. Has the same type as input "var".
  1582. */
  1583. REG_OP(ApplyAdam)
  1584. .INPUT(var, TensorType::NumberType())
  1585. .INPUT(m, TensorType::NumberType())
  1586. .INPUT(v, TensorType::NumberType())
  1587. .INPUT(beta1_power, TensorType::NumberType())
  1588. .INPUT(beta2_power, TensorType::NumberType())
  1589. .INPUT(lr, TensorType::NumberType())
  1590. .INPUT(beta1, TensorType::NumberType())
  1591. .INPUT(beta2, TensorType::NumberType())
  1592. .INPUT(epsilon, TensorType::NumberType())
  1593. .INPUT(grad, TensorType::NumberType())
  1594. .OUTPUT(var, TensorType::NumberType())
  1595. .ATTR(use_locking, Bool, false)
  1596. .ATTR(use_nesterov, Bool, false)
  1597. .OP_END_FACTORY_REG(ApplyAdam)
  1598. /**
  1599. *@brief Updates "var", "m" and "v" according to the Adam algorithm.\n
  1600. * lr_t <- lr * sqrt(1 - beta2_power) / (1 - beta1_power)\n
  1601. * m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
  1602. * v_t <- beta2 * v_{t-1} + (1 - beta2) * grad * grad\n
  1603. * var <- var - lr_t * m_t / (sqrt(v_t) + epsilon)
  1604. *
  1605. *@attention Constraints:\n
  1606. * The input tensors "var", "m", "v" and "grad" must have the same shape.
  1607. *
  1608. *@par Inputs:
  1609. *@li var: A mutable Tensor of the type TensorType::NumberType().
  1610. * Should be from a Variable().
  1611. *@li m: A mutable Tensor of the same type as "var".
  1612. * Should be from a Variable().
  1613. *@li v: A mutable Tensor of the same type as "var".
  1614. * Should be from a Variable().
  1615. *@li beta1_power: A scalar of the same type as "var", holding beta1^t.
  1616. *@li beta2_power: A scalar of the same type as "var", holding beta2^t.
  1617. *@li lr: learning_rate. A scalar of the same type as "var".
  1618. *@li beta1: A scalar of the same type as "var".
  1619. *@li beta2: A scalar of the same type as "var".
  1620. *@li epsilon: A scalar of the same type as "var".
  1621. *@li grad: A Tensor of the same type as "var", for the gradient.
  1622. *
  1623. *@par Attributes:\n
  1624. *@li use_locking: An optional bool. Defaults to "False".
  1625. * If "True", updating of the "var", "m", and "v" tensors will be protected
  1626. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1627. * contention.
  1628. *@li use_nesterov: An optional bool. Defaults to "False".
  1629. * If "True", uses the nesterov update.
  1630. *
  1631. *@par Outputs:
  1632. *@li var: A mutable tensor. Has the same type as input "var".
  1633. *@li m: A mutable tensor. Has the same type as input "m".
  1634. *@li v: A mutable tensor. Has the same type as input "v".
  1635. */
  1636. REG_OP(ApplyAdamD)
  1637. .INPUT(var, TensorType::NumberType())
  1638. .INPUT(m, TensorType::NumberType())
  1639. .INPUT(v, TensorType::NumberType())
  1640. .INPUT(beta1_power, TensorType::NumberType())
  1641. .INPUT(beta2_power, TensorType::NumberType())
  1642. .INPUT(lr, TensorType::NumberType())
  1643. .INPUT(beta1, TensorType::NumberType())
  1644. .INPUT(beta2, TensorType::NumberType())
  1645. .INPUT(epsilon, TensorType::NumberType())
  1646. .INPUT(grad, TensorType::NumberType())
  1647. .OUTPUT(var, TensorType::NumberType())
  1648. .OUTPUT(m, TensorType::NumberType())
  1649. .OUTPUT(v, TensorType::NumberType())
  1650. .ATTR(use_locking, Bool, false)
  1651. .ATTR(use_nesterov, Bool, false)
  1652. .OP_END_FACTORY_REG(ApplyAdamD)
  1653. /**
  1654. *@brief Updates "var" according to the Adadelta scheme.
  1655. *@par Inputs:
  1656. *Seven inputs, including:
  1657. * @li var: A mutable Tensor of type TensorType::NumberType().
  1658. * Should be a Variable Tensor.
  1659. * @li accum: A mutable Tensor of the same type as "var".
  1660. * Should be a Variable Tensor.
  1661. * @li accum_update: A mutable Tensor of the same type as "var".
  1662. * Should be a Variable Tensor.
  1663. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1664. * @li rho: A scalar of the same type as "var", for the decay factor.
  1665. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  1666. * @li grad: A Tensor of the same type as "var", for the gradient.
  1667. *@par Attributes:
  1668. *use_locking: An optional bool. Defaults to "False".
  1669. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  1670. * protected by a lock; otherwise the behavior is undefined,
  1671. * but may exhibit less contention.
  1672. *@par Outputs:
  1673. *var: A mutable Tensor. Has the same type as "var".
  1674. */
  1675. REG_OP(ApplyAdadelta)
  1676. .INPUT(var, TensorType::NumberType())
  1677. .INPUT(accum, TensorType::NumberType())
  1678. .INPUT(accum_update, TensorType::NumberType())
  1679. .INPUT(lr, TensorType::NumberType())
  1680. .INPUT(rho, TensorType::NumberType())
  1681. .INPUT(epsilon, TensorType::NumberType())
  1682. .INPUT(grad, TensorType::NumberType())
  1683. .OUTPUT(var, TensorType::NumberType())
  1684. .ATTR(use_locking, Bool, false)
  1685. .OP_END_FACTORY_REG(ApplyAdadelta)
  1686. /**
  1687. *@brief Updates "var", "accum" and "accum_update" according to the Adadelta scheme.
  1688. *@par Inputs:
  1689. *Seven inputs, including:
  1690. * @li var: A mutable Tensor of type TensorType::NumberType().
  1691. * Should be a Variable Tensor.
  1692. * @li accum: A mutable Tensor of the same type as "var".
  1693. * Should be a Variable Tensor.
  1694. * @li accum_update: A mutable Tensor of the same type as "var".
  1695. * Should be a Variable Tensor.
  1696. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1697. * @li rho: A scalar of the same type as "var", for the decay factor.
  1698. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  1699. * @li grad: A Tensor of the same type as "var", for the gradient.
  1700. *@par Attributes:
  1701. *use_locking: An optional bool. Defaults to "False".
  1702. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  1703. * protected by a lock; otherwise the behavior is undefined,
  1704. * but may exhibit less contention.
  1705. *@par Outputs:
  1706. *@li var: A mutable Tensor. Has the same type as "var".
  1707. *@li accum: A mutable Tensor. Has the same type as "var".
  1708. *@li accum_update: A mutable Tensor. Has the same type as "var".
  1709. */
  1710. REG_OP(ApplyAdadeltaD)
  1711. .INPUT(var, TensorType::NumberType())
  1712. .INPUT(accum, TensorType::NumberType())
  1713. .INPUT(accum_update, TensorType::NumberType())
  1714. .INPUT(lr, TensorType::NumberType())
  1715. .INPUT(rho, TensorType::NumberType())
  1716. .INPUT(epsilon, TensorType::NumberType())
  1717. .INPUT(grad, TensorType::NumberType())
  1718. .OUTPUT(var, TensorType::NumberType())
  1719. .OUTPUT(accum, TensorType::NumberType())
  1720. .OUTPUT(accum_update, TensorType::NumberType())
  1721. .ATTR(use_locking, Bool, false)
  1722. .OP_END_FACTORY_REG(ApplyAdadeltaD)
  1723. /**
  1724. * @brief Updates "var" and "accum" according to the ApplyMomentum algorithm. \n
  1725. * accum = accum * momentum + x1 * x2 \n
  1726. * if use_nesterov is True: \n
  1727. * var -= x1 * x2 * lr + accum * momentum * lr \n
  1728. * else:\n
  1729. * var -= accum * lr
  1730. *
  1731. * @par Inputs:
  1732. * Six inputs, including:
  1733. * @li var: A mutable Tensor of type TensorType::NumberType().
  1734. * Should be a Variable Tensor.
  1735. * @li accum: A mutable Tensor of the same type as "var".
  1736. * Should be a Variable Tensor.
  1737. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1738. * @li x1: A Tensor of type TensorType::NumberType(). The product x1 * x2 acts as the gradient.
  1739. * @li momentum: A scalar of the same type as "var".
  1740. * @li x2: A scalar of the same type as "var". The product x1 * x2 acts as the gradient.
  1741. *
  1742. * @par Attributes:
  1743. * Two attributes, including:
  1744. * @li use_nesterov: An optional bool. Defaults to "False". \n
  1745. * If True, the tensor passed to compute grad will be var - lr * momentum * accum, \n
  1746. * so in the end, the var you get is actually var - lr * momentum * accum.
  1747. * @li use_locking: An optional bool. Defaults to "False". \n
  1748. * If "True", updating of the "var" and "accum" tensors will be protected \n
  1749. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1750. *
  1751. * @par Outputs:
  1752. * Two outputs, including:
  1753. * @li var: A mutable Tensor of the same type as input "var".
  1754. * @li accum: A mutable Tensor of the same type as input "var".
  1755. */
  1756. REG_OP(FusedMulApplyMomentum)
  1757. .INPUT(var, TensorType::NumberType())
  1758. .INPUT(accum, TensorType::NumberType())
  1759. .INPUT(lr, TensorType::NumberType())
  1760. .INPUT(x1, TensorType::NumberType())
  1761. .INPUT(momentum, TensorType::NumberType())
  1762. .INPUT(x2, TensorType::NumberType())
  1763. .OUTPUT(var, TensorType::NumberType())
  1764. .OUTPUT(accum, TensorType::NumberType())
  1765. .ATTR(use_nesterov, Bool, false)
  1766. .ATTR(use_locking, Bool, false)
  1767. .OP_END_FACTORY_REG(FusedMulApplyMomentum)
  1768. /**
  1769. * @brief Updates "var" and "accum" according to the ApplyMomentum algorithm. \n
  1770. * accum = accum * momentum + x1 * x2 \n
  1771. * if use_nesterov is True: \n
  1772. * var -= x1 * x2 * lr + accum * momentum * lr \n
  1773. * else: \n
  1774. * var -= accum * lr
  1775. *
  1776. * @par Inputs:
  1777. * Seven inputs, including:
  1778. * @li var: A mutable Tensor of type float32.
  1779. * Should be a Variable Tensor.
  1780. * @li accum: A mutable Tensor of type TensorType::NumberType().
  1781. * Should be a Variable Tensor.
  1782. * @li lr: A scalar of the same type as "accum", for the scaling factor.
  1783. * @li x1: A Tensor of the same type as "accum".
  1784. * @li momentum: A scalar of the same type as "accum".
  1785. * @li x2: A scalar of the same type as "accum".
  1786. * @li var_copy: A Tensor of type float16.
  1787. *
  1788. * @par Attributes:
  1789. * Two attributes, including:
  1790. * @li use_nesterov: An optional bool. Defaults to "False". \n
  1791. * If True, the tensor passed to compute grad will be var - lr * momentum * accum, \n
  1792. * so in the end, the var you get is actually var - lr * momentum * accum.
  1793. * @li use_locking: An optional bool. Defaults to "False". \n
  1794. * If "True", updating of the "var" and "accum" tensors will be protected \n
  1795. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1796. *
  1797. * @par Outputs:
  1798. * Three outputs, including:
  1799. * @li var: A Tensor of type float32.
  1800. * @li var_copy: A Tensor of type float16.
  1801. * @li accum: A Tensor of the same type as input "accum".
  1802. */
  1803. REG_OP(FusedMulApplyMomentumExtern)
  1804. .INPUT(var, TensorType(DT_FLOAT))
  1805. .INPUT(accum, TensorType::NumberType())
  1806. .INPUT(lr, TensorType::NumberType())
  1807. .INPUT(x1, TensorType::NumberType())
  1808. .INPUT(momentum, TensorType::NumberType())
  1809. .INPUT(x2, TensorType::NumberType())
  1810. .INPUT(var_copy, TensorType(DT_FLOAT16))
  1811. .OUTPUT(var, TensorType(DT_FLOAT))
  1812. .OUTPUT(var_copy, TensorType(DT_FLOAT16))
  1813. .OUTPUT(accum, TensorType::NumberType())
  1814. .ATTR(use_nesterov, Bool, false)
  1815. .ATTR(use_locking, Bool, false)
  1816. .OP_END_FACTORY_REG(FusedMulApplyMomentumExtern)
  1817. /**
  1818. *@brief Update "g" according to the LARS algorithm.
  1819. *@par Inputs:
  1820. *Four inputs, including:
  1821. * @li w: A Tensor. Must be of type float32.
  1822. * @li g: A Tensor of the same type and shape as "w".
  1823. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
  1824. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar.
  1825. *@par Attributes:
  1826. *Three Attributes, including:
  1827. * @li hyperpara: An optional float. Defaults to 0.001.
  1828. * @li epsilon: An optional float. Defaults to 1e-5. Prevents the denominator from being 0.
  1829. * @li use_clip: An optional bool. Defaults to "False".\n
  1830. * If "True", the learning rate is updated.
  1831. *@par Outputs:
  1832. *g_new: A Tensor of the same type as "w".
  1833. */
  1834. REG_OP(LarsV2)
  1835. .INPUT(w, TensorType(DT_FLOAT))
  1836. .INPUT(g, TensorType(DT_FLOAT))
  1837. .INPUT(weight_decay, TensorType(DT_FLOAT))
  1838. .INPUT(learning_rate, TensorType(DT_FLOAT))
  1839. .OUTPUT(g_new, TensorType(DT_FLOAT))
  1840. .ATTR(hyperpara, Float, 0.001)
  1841. .ATTR(epsilon, Float, 0.00001)
  1842. .ATTR(use_clip, Bool, false)
  1843. .OP_END_FACTORY_REG(LarsV2)
  1844. /**
  1845. *@brief Update "g" according to the LARS algorithm.
  1846. *@par Inputs:
  1847. *Six inputs, including:
  1848. * @li w: A Tensor. Must be of type TensorType::DT_FLOAT.
  1849. * @li g: A Tensor of the same type and shape as "w".
  1850. * @li w_square_sum: A Tensor of square_sum(w), has the same type as "w", Must be a scalar.
  1851. * @li g_square_sum: A Tensor of square(g), has the same type as "w", Must be a scalar.
  1852. * @li weight_decay: A Tensor of the same type as "w", Must be a scalar.
  1853. * @li learning_rate: A Tensor of the same type as "w", Must be a scalar.
  1854. *@par Attributes:
  1855. *Three Attributes, including:
  1856. * @li hyperpara: An optional float. Default value is 0.001.
  1857. * @li epsilon: An optional float. Default value is 1e-5.Avoid denominator is 0.
  1858. * @li use_clip: An optional bool. Defaults to "False".\n
  1859. * If "True", updating learning rate.
  1860. *@par Outputs:
  1861. *g_new: Tensor of the same type as "w".
  1862. */
  1863. REG_OP(LarsV2Update)
  1864. .INPUT(w, TensorType(DT_FLOAT))
  1865. .INPUT(g, TensorType(DT_FLOAT))
  1866. .INPUT(w_square_sum, TensorType(DT_FLOAT))
  1867. .INPUT(g_square_sum, TensorType(DT_FLOAT))
  1868. .INPUT(weight_decay, TensorType(DT_FLOAT))
  1869. .INPUT(learning_rate, TensorType(DT_FLOAT))
  1870. .OUTPUT(g_new, TensorType(DT_FLOAT))
  1871. .ATTR(hyperpara, Float, 0.001)
  1872. .ATTR(epsilon, Float, 0.00001)
  1873. .ATTR(use_clip, Bool, false)
  1874. .OP_END_FACTORY_REG(LarsV2Update)
  1875. /**
  1876. * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme.
  1877. * @par Inputs:
  1878. * Nine inputs, including:
  1879. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1880. * Should be a Variable Tensor.
  1881. * @li accum: A mutable Tensor of the same type as "var".
  1882. * Should be a Variable Tensor.
  1883. * @li linear: A mutable Tensor of the same type as "var".
  1884. * Should be a Variable Tensor.
  1885. * @li grad: A Tensor of the same type as "var", for the gradient.
  1886. * @li indices: A vector of indices into the first dimension of var and accum.
  1887. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1888. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  1889. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  1890. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1891. * @par Attributes:
  1892. * use_locking: An optional bool. Defaults to "False".
  1893. * If "True", updating of the "var" and "accum" tensors will be
  1894. * protected by a lock; otherwise the behavior is undefined,
  1895. * but may exhibit less contention.
  1896. * @par Outputs:
  1897. * var: A Tensor. Has the same type and format as input "var".
  1898. */
  1899. REG_OP(SparseApplyFtrl)
  1900. .INPUT(var, TensorType({DT_FLOAT}))
  1901. .INPUT(accum, TensorType({DT_FLOAT}))
  1902. .INPUT(linear, TensorType({DT_FLOAT}))
  1903. .INPUT(grad, TensorType({DT_FLOAT}))
  1904. .INPUT(indices, TensorType({DT_INT32}))
  1905. .INPUT(lr, TensorType({DT_FLOAT}))
  1906. .INPUT(l1, TensorType({DT_FLOAT}))
  1907. .INPUT(l2, TensorType({DT_FLOAT}))
  1908. .INPUT(lr_power, TensorType({DT_FLOAT}))
  1909. .OUTPUT(var, TensorType({DT_FLOAT}))
  1910. .ATTR(use_locking, Bool, false)
  1911. .OP_END_FACTORY_REG(SparseApplyFtrl)
  1912. /**
  1913. * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme.
  1914. * @par Inputs:
  1915. * Five inputs, including:
  1916. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1917. * Should be a Variable Tensor.
  1918. * @li accum: A mutable Tensor of the same type as "var".
  1919. * Should be a Variable Tensor.
  1920. * @li linear: A mutable Tensor of the same type as "var".
  1921. * Should be a Variable Tensor.
  1922. * @li grad: A Tensor of the same type as "var", for the gradient.
  1923. * @li indices: A vector of indices into the first dimension of var and accum.
  1924. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1925. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  1926. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  1927. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1928. * @par Attributes:
  1929. * use_locking: An optional bool. Defaults to "False".
  1930. * If "True", updating of the "var" and "accum" tensors will be
  1931. * protected by a lock; otherwise the behavior is undefined,
  1932. * but may exhibit less contention.
  1933. * @par Outputs:
  1934. * @li var: A Tensor. Has the same type and format as input "var".
  1935. * @li accum: A Tensor. Has the same type and format as input "accum".
  1936. * @li linear: A Tensor. Has the same type and format as input "linear".
  1937. */
  1938. REG_OP(SparseApplyFtrlD)
  1939. .INPUT(var, TensorType({DT_FLOAT}))
  1940. .INPUT(accum, TensorType({DT_FLOAT}))
  1941. .INPUT(linear, TensorType({DT_FLOAT}))
  1942. .INPUT(grad, TensorType({DT_FLOAT}))
  1943. .INPUT(indices, TensorType({DT_INT32}))
  1944. .OUTPUT(var, TensorType({DT_FLOAT}))
  1945. .OUTPUT(accum, TensorType({DT_FLOAT}))
  1946. .OUTPUT(linear, TensorType({DT_FLOAT}))
  1947. .REQUIRED_ATTR(lr, Float)
  1948. .REQUIRED_ATTR(l1, Float)
  1949. .REQUIRED_ATTR(l2, Float)
  1950. .REQUIRED_ATTR(lr_power, Float)
  1951. .ATTR(use_locking, Bool, false)
  1952. .OP_END_FACTORY_REG(SparseApplyFtrlD)
  1953. /**
  1954. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  1955. * That is for rows we have grad for, "var", "accum" and "linear" are updated.
  1956. * @par Inputs:
  1957. * Ten inputs, including:
  1958. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1959. * Should be a Variable Tensor.
  1960. * @li accum: A mutable Tensor of the same type as "var".
  1961. * Should be a Variable Tensor.
  1962. * @li linear: A mutable Tensor of the same type as "var".
  1963. * Should be a Variable Tensor.
  1964. * @li grad: A Tensor of the same type as "var", for the gradient.
  1965. * @li indices: A vector of indices into the first dimension of "var" and "accum".
  1966. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1967. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  1968. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  1969. * @li l2_shrinkage: A Tensor of the same type as "var", L2 shrinkage regulariation. Must be a scalar.
  1970. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1971. * @par Attributes:
  1972. * use_locking: An optional bool. Defaults to "False".
  1973. * If "True", updating of the "var" and "accum" tensors will be
  1974. * protected by a lock; otherwise the behavior is undefined,
  1975. * but may exhibit less contention.
  1976. * @par Outputs:
  1977. * var: A Tensor. Has the same type and format as input "var".
  1978. */
  1979. REG_OP(SparseApplyFtrlV2)
  1980. .INPUT(var, TensorType({DT_FLOAT}))
  1981. .INPUT(accum, TensorType({DT_FLOAT}))
  1982. .INPUT(linear, TensorType({DT_FLOAT}))
  1983. .INPUT(grad, TensorType({DT_FLOAT}))
  1984. .INPUT(indices, TensorType({DT_INT32}))
  1985. .INPUT(lr, TensorType({DT_FLOAT}))
  1986. .INPUT(l1, TensorType({DT_FLOAT}))
  1987. .INPUT(l2, TensorType({DT_FLOAT}))
  1988. .INPUT(l2_shrinkage, TensorType({DT_FLOAT}))
  1989. .INPUT(lr_power, TensorType({DT_FLOAT}))
  1990. .OUTPUT(var, TensorType({DT_FLOAT}))
  1991. .ATTR(use_locking, Bool, false)
  1992. .OP_END_FACTORY_REG(SparseApplyFtrlV2)
  1993. /**
  1994. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  1995. * That is for rows we have grad for, "var", "accum" and "linear" are updated.
  1996. * @par Inputs:
  1997. * Five inputs, including:
  1998. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1999. * Should be a Variable Tensor.
  2000. * @li accum: A mutable Tensor of the same type as "var".
  2001. * Should be a Variable Tensor.
  2002. * @li linear: A mutable Tensor of the same type as "var".
  2003. * Should be a Variable Tensor.
  2004. * @li grad: A Tensor of the same type as "var", for the gradient.
  2005. * @li indices: A vector of indices into the first dimension of "var" and "accum".
  2006. * @par Attributes:
  2007. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  2008. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar.
  2009. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar.
  2010. * @li l2_shrinkage: A Tensor of the same type as "var", L2 shrinkage regulariation. Must be a scalar.
  2011. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  2012. * @li use_locking: An optional bool. Defaults to "False".
  2013. * If "True", updating of the "var" and "accum" tensors will be
  2014. * protected by a lock; otherwise the behavior is undefined,
  2015. * but may exhibit less contention.
  2016. * @par Outputs:
  2017. * @li var: A Tensor. Has the same type and format as input "var".
  2018. * @li accum: A Tensor. Has the same type and format as input "accum".
  2019. * @li linear: A Tensor. Has the same type and format as input "linear".
  2020. */
  2021. REG_OP(SparseApplyFtrlV2D)
  2022. .INPUT(var, TensorType({DT_FLOAT}))
  2023. .INPUT(accum, TensorType({DT_FLOAT}))
  2024. .INPUT(linear, TensorType({DT_FLOAT}))
  2025. .INPUT(grad, TensorType({DT_FLOAT}))
  2026. .INPUT(indices, TensorType({DT_INT32}))
  2027. .OUTPUT(var, TensorType({DT_FLOAT}))
  2028. .OUTPUT(accum, TensorType({DT_FLOAT}))
  2029. .OUTPUT(linear, TensorType({DT_FLOAT}))
  2030. .REQUIRED_ATTR(lr, Float)
  2031. .REQUIRED_ATTR(l1, Float)
  2032. .REQUIRED_ATTR(l2, Float)
  2033. .REQUIRED_ATTR(l2_shrinkage, Float)
  2034. .REQUIRED_ATTR(lr_power, Float)
  2035. .ATTR(use_locking, Bool, false)
  2036. .OP_END_FACTORY_REG(SparseApplyFtrlV2D)
  2037. /**
  2038. * @brief Updates "var" in specified index according to the RMSProp algorithm.
  2039. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  2040. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  2041. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  2042. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  2043. * var <- var - mom\n
  2044. *
  2045. * @par Inputs:
  2046. * @li var: A mutable tensor. Must be one of the data types defined in\n
  2047. * TensorType::NumberType(). Should be from a Variable().
  2048. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  2049. * Variable().
  2050. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  2051. * Variable().
  2052. * @li lr: A scalar. Must have the same type as "var".
  2053. * @li rho: A scalar. Must have the same type as "var".
  2054. * @li momentum: A scalar. Must have the same type as "var".
  2055. * @li epsilon: A scalar. Must have the same type as "var".
  2056. * @li grad: A tensor, specifying the gradient.
  2057. * @li indices: A vector of indices into the first dimension of "var", "mom" and "ms".
  2058. *
  2059. * @par Attributes:
  2060. * use_locking: An optional "bool". Defaults to "False". If "True", updating of
  2061. * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the
  2062. * behavior is undefined, but may exhibit less contention.
  2063. *
  2064. * @par Outputs:
  2065. * var: A mutable tensor. Has the same type as input "var".
  2066. *
  2067. * @attention Constraints:
  2068. * @li Note that in this sparse implementation, "ms" and "mom" will not update
  2069. * in iterations during which "grad" is 0.
  2070. * @li The input tensors "var", "ms", and "mom" must have the same shape.
  2071. *
  2072. */
  2073. REG_OP(SparseApplyRMSProp)
  2074. .INPUT(var, TensorType::NumberType())
  2075. .INPUT(ms, TensorType::NumberType())
  2076. .INPUT(mom, TensorType::NumberType())
  2077. .INPUT(lr, TensorType::NumberType())
  2078. .INPUT(rho, TensorType::NumberType())
  2079. .INPUT(momentum, TensorType::NumberType())
  2080. .INPUT(epsilon, TensorType::NumberType())
  2081. .INPUT(grad, TensorType::NumberType())
  2082. .INPUT(indices, TensorType::IndexNumberType())
  2083. .OUTPUT(var, TensorType::NumberType())
  2084. .ATTR(use_locking, Bool, false)
  2085. .OP_END_FACTORY_REG(SparseApplyRMSProp)
  2086. /**
  2087. * @brief Updates "var" in specified index according to the RMSProp algorithm.
  2088. * a const input will be considered as an attribute.\n
  2089. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  2090. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  2091. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  2092. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  2093. * var <- var - mom
  2094. *
  2095. * @par Inputs:
  2096. * @li var: A mutable tensor. Must be one of the data types defined in
  2097. * TensorType::NumberType(). Should be from a Variable().
  2098. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  2099. * Variable().
  2100. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  2101. * Variable().
  2102. * @li lr: A scalar. Must have the same type as "var".
  2103. * @li grad: A tensor, specifying the gradient.
  2104. *
  2105. * @par Attributes:
  2106. * @li use_locking: An optional "bool". Defaults to "False". If "True",
  2107. * updating of the "var", "ms", and "mom" tensors will be protected by a lock;
  2108. * otherwise the behavior is undefined, but may exhibit less contention.
  2109. * @li rho: A required scalar. Must have the same type as "var".
  2110. * @li momentum: A required scalar. Must have the same type as "var".
  2111. * @li epsilon: A required scalar. Must have the same type as "var".
  2112. *
  2113. * @par Outputs:
  2114. * @li var: A mutable tensor. Must have the same type as input "var".
  2115. * @li ms: A mutable tensor. Must have the same type as input "ms".
  2116. * @li mom: A mutable tensor. Must have the same type as input "mom".
  2117. *
  2118. * @attention Constraints:
  2119. * @li Note that in this sparse implementation, "ms" and "mom" will not update
  2120. * in iterations during which "grad" is 0.
  2121. * @li The input tensors "var", "ms" and "mom" must have the same shape.
  2122. */
  2123. REG_OP(SparseApplyRMSPropD)
  2124. .INPUT(var, TensorType::NumberType())
  2125. .INPUT(ms, TensorType::NumberType())
  2126. .INPUT(mom, TensorType::NumberType())
  2127. .INPUT(lr, TensorType::NumberType())
  2128. .INPUT(grad, TensorType::NumberType())
  2129. .INPUT(indices, TensorType::IndexNumberType())
  2130. .OUTPUT(var, TensorType::NumberType())
  2131. .OUTPUT(ms, TensorType::NumberType())
  2132. .OUTPUT(mom, TensorType::NumberType())
  2133. .REQUIRED_ATTR(rho, Float)
  2134. .REQUIRED_ATTR(momentum, Float)
  2135. .REQUIRED_ATTR(epsilon, Float)
  2136. .ATTR(use_locking, Bool, false)
  2137. .OP_END_FACTORY_REG(SparseApplyRMSPropD)
  2138. /**
  2139. * @brief Updates "var" in specified index according to the Adadelta algorithm.
  2140. * accum <- rho * accum + (1 - rho) * grad.square()\n
  2141. * update <- (accum_update + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad\n
  2142. * var <- var - update * lr\n
  2143. * accum_update <- rho() * accum_update + (1 - rho()) * update.square()\n
  2144. *
  2145. * @par Inputs:
  2146. * @li var: A mutable tensor. Must be one of the data types defined in\n
  2147. * TensorType::NumberType(). Should be from a Variable().
  2148. * @li accum: A mutable tensor. Must have the same type as "var". Should be from a
  2149. * Variable().
  2150. * @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a
  2151. * Variable().
  2152. * @li lr: A scalar. Must have the same type as "var".
  2153. * @li rho: A scalar. Must have the same type as "var".
  2154. * @li epsilon: A scalar. Must have the same type as "var".
  2155. * @li grad: A tensor, specifying the gradient.
  2156. * @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update".
  2157. *
  2158. * @par Attributes:
  2159. * use_locking: An optional "bool". Defaults to "False". If "True", updating of
  2160. * the "var", "accum", and "accum_update" tensors will be protected by a lock; otherwise the
  2161. * behavior is undefined, but may exhibit less contention.
  2162. *
  2163. * @par Outputs:
  2164. * var: A mutable tensor. Has the same type as input "var".
  2165. *
  2166. * @attention Constraints:
  2167. * @li Note that in this sparse implementation, "accum" and "accum_update" will not update
  2168. * in iterations during which "grad" is 0.
  2169. * @li The input tensors "var", "accum", and "accum_update" must have the same shape.
  2170. *
  2171. */
  2172. REG_OP(SparseApplyAdadelta)
  2173. .INPUT(var, TensorType::NumberType())
  2174. .INPUT(accum, TensorType::NumberType())
  2175. .INPUT(accum_update, TensorType::NumberType())
  2176. .INPUT(lr, TensorType::NumberType())
  2177. .INPUT(rho, TensorType::NumberType())
  2178. .INPUT(epsilon, TensorType::NumberType())
  2179. .INPUT(grad, TensorType::NumberType())
  2180. .INPUT(indices, TensorType::IndexNumberType())
  2181. .OUTPUT(var, TensorType::NumberType())
  2182. .ATTR(use_locking, Bool, false)
  2183. .OP_END_FACTORY_REG(SparseApplyAdadelta)
  2184. /**
  2185. * @brief Updates "var" in specified index according to the Adadelta algorithm.
  2186. * a const input will be considered as an attribute.\n
  2187. * accum <- rho * accum + (1 - rho) * grad.square()\n
  2188. * update <- (accum_update + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad\n
  2189. * var <- var - update * lr\n
  2190. * accum_update <- rho() * accum_update + (1 - rho()) * update.square()\n
  2191. *
  2192. * @par Inputs:
  2193. * @li var: A mutable tensor. Must be one of the data types defined in
  2194. * TensorType::NumberType(). Should be from a Variable().
  2195. * @li accum: A mutable tensor. Must have the same type as "var". Should be from a
  2196. * Variable().
  2197. * @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a
  2198. * Variable().
  2199. * @li lr: A scalar. Must have the same type as "var".
  2200. * @li rho: A scalar. Must have the same type as "var".
  2201. * @li grad: A tensor, specifying the gradient.
  2202. * @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update".
  2203. *
  2204. * @par Attributes:
  2205. * @li use_locking: An optional "bool". Defaults to "False". If "True",
  2206. * updating of the "var", "accum", and "accum_update" tensors will be protected by a lock;
  2207. * otherwise the behavior is undefined, but may exhibit less contention.
  2208. * @li epsilon: A required scalar. Must have the same type as "var".
  2209. *
  2210. * @par Outputs:
  2211. * @li var: A mutable tensor. Must have the same type as input "var".
  2212. * @li accum: A mutable tensor. Must have the same type as input "accum".
  2213. * @li accum_update: A mutable tensor. Must have the same type as input "accum_update".
  2214. *
  2215. * @attention Constraints:
  2216. * @li Note that in this sparse implementation, "accum" and "accum_update" will not update
  2217. * in iterations during which "grad" is 0.
  2218. * @li The input tensors "var", "accum" and "accum_update" must have the same shape.
  2219. */
  2220. REG_OP(SparseApplyAdadeltaD)
  2221. .INPUT(var, TensorType::NumberType())
  2222. .INPUT(accum, TensorType::NumberType())
  2223. .INPUT(accum_update, TensorType::NumberType())
  2224. .INPUT(lr, TensorType::NumberType())
  2225. .INPUT(rho, TensorType::NumberType())
  2226. .INPUT(grad, TensorType::NumberType())
  2227. .INPUT(indices, TensorType::IndexNumberType())
  2228. .OUTPUT(var, TensorType::NumberType())
  2229. .OUTPUT(accum, TensorType::NumberType())
  2230. .OUTPUT(accum_update, TensorType::NumberType())
  2231. .REQUIRED_ATTR(epsilon, Float)
  2232. .ATTR(use_locking, Bool, false)
  2233. .OP_END_FACTORY_REG(SparseApplyAdadeltaD)
  2234. /**
  2235. *@brief Clean memory of workspace list.
  2236. *@par Attributes:
  2237. * @li automic_add_mem_size: sizes of workspaces.
  2238. */
  2239. REG_OP(AtomicAddrClean)
  2240. .ATTR(automic_add_mem_size, ListInt, {})
  2241. .OP_END_FACTORY_REG(AtomicAddrClean)
  2242. } // namespace ge
  2243. #endif // GE_OP_TRAINING_OPS_H

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：