You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

nn_training_ops.h 100 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_OP_TRAINING_OPS_H
  17. #define GE_OP_TRAINING_OPS_H
  18. #include "graph/operator_reg.h"
  19. namespace ge {
  20. /**
  21. *@brief Updates "var" according to the AdaMax algorithm.
  22. * "t-1" denotes the previous step.
  23. * m_t <- beta1 * m{t-1} + (1 - beta1) * grad\n
  24. * v_t <- max(beta2 * v{t-1}, abs(grad))\n
  25. * var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
  26. *
  27. *@attention Constraints:
  28. * the input tensors must have the same shape.
  29. *
  30. *@par Inputs:
  31. *@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
  32. * Should be from a Variable().
  33. *@li m: A mutable tensor. Has the same type as "var".
  34. * Should be from a Variable().
  35. *@li v: A mutable tensor. Has the same type as "var".
  36. * Should be from a Variable().
  37. *@li beta1_power: A scalar. Has the same type as "var".
  38. *@li lr: learning_rate. A scalar. Has the same type as "var".
  39. *@li beta1: A scalar. Has the same type as "var".
  40. *@li beta2: A scalar. Has the same type as "var".
  41. *@li epsilon: A scalar. Has the same type as "var".
  42. *@li grad: A tensor for the gradient. Has the same type as "var".
  43. *
  44. *@par Attributes:
  45. * use_locking: An optional bool. Defaults to "False".
  46. * If "True", updating of the "var", "m", and "v" tensors is protected
  47. * by a lock; otherwise the behavior is undefined, but may exhibit less
  48. * contention.
  49. *
  50. *@par Outputs:
  51. * var: A mutable tensor. Has the same type as input "var".
  52. *
  53. *@par Third-party framework compatibility
  54. *Compatible with the TensorFlow operator ApplyAdaMax.
  55. *
  56. */
  57. REG_OP(ApplyAdaMax)
  58. .INPUT(var, TensorType::NumberType())
  59. .INPUT(m, TensorType::NumberType())
  60. .INPUT(v, TensorType::NumberType())
  61. .INPUT(beta1_power, TensorType::NumberType())
  62. .INPUT(lr, TensorType::NumberType())
  63. .INPUT(beta1, TensorType::NumberType())
  64. .INPUT(beta2, TensorType::NumberType())
  65. .INPUT(epsilon, TensorType::NumberType())
  66. .INPUT(grad, TensorType::NumberType())
  67. .OUTPUT(var, TensorType::NumberType())
  68. .ATTR(use_locking, Bool, false)
  69. .OP_END_FACTORY_REG(ApplyAdaMax)
  70. /**
  71. *@brief Updates "var" according to the AdaMax algorithm.
  72. * "t-1" denotes the previous step.
  73. * m_t <- beta1 * m{t-1} + (1 - beta1) * grad\n
  74. * v_t <- max(beta2 * v{t-1}, abs(grad))\n
  75. * var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
  76. *
  77. *@attention Constraints:
  78. * the input tensors must have the same shape.
  79. *
  80. *@par Inputs:
  81. *@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
  82. * Should be from a Variable().
  83. *@li m: A mutable tensor. Has the same type as "var".
  84. * Should be from a Variable().
  85. *@li v: A mutable tensor. Has the same type as "var".
  86. * Should be from a Variable().
  87. *@li beta1_power: A scalar. Has the same type as "var".
  88. *@li lr: learning_rate. A scalar. Has the same type as "var".
  89. *@li beta1: A scalar. Has the same type as "var".
  90. *@li beta2: A scalar. Has the same type as "var".
  91. *@li epsilon: A scalar. Has the same type as "var".
  92. *@li grad: A tensor for the gradient. Has the same type as "var".
  93. *
  94. *@par Attributes:
  95. * use_locking: An optional bool. Defaults to "False".
  96. * If "True", updating of the "var", "m", and "v" tensors is protected
  97. * by a lock; otherwise the behavior is undefined, but may exhibit less
  98. * contention.
  99. *
  100. *@par Outputs:
  101. *@li var: A mutable tensor. Has the same type as input "var".
  102. *@li m: A mutable tensor. Has the same type as input "m".
  103. *@li v: A mutable tensor. Has the same type as input "v".
  104. *
  105. *@par Third-party framework compatibility
  106. *Compatible with the TensorFlow operator ApplyAdaMax.
  107. *
  108. */
  109. REG_OP(ApplyAdaMaxD)
  110. .INPUT(var, TensorType::NumberType())
  111. .INPUT(m, TensorType::NumberType())
  112. .INPUT(v, TensorType::NumberType())
  113. .INPUT(beta1_power, TensorType::NumberType())
  114. .INPUT(lr, TensorType::NumberType())
  115. .INPUT(beta1, TensorType::NumberType())
  116. .INPUT(beta2, TensorType::NumberType())
  117. .INPUT(epsilon, TensorType::NumberType())
  118. .INPUT(grad, TensorType::NumberType())
  119. .OUTPUT(var, TensorType::NumberType())
  120. .OUTPUT(m, TensorType::NumberType())
  121. .OUTPUT(v, TensorType::NumberType())
  122. .ATTR(use_locking, Bool, false)
  123. .OP_END_FACTORY_REG(ApplyAdaMaxD)
  124. /**
  125. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  126. *@par Inputs:
  127. * Five inputs, including:
  128. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  129. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  130. *@li lr: An NCHW, NHWC, or ND Tensor of type float32.
  131. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  132. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  133. *@par Attributes:
  134. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  135. *@li update_slots: An optional bool. Defaults to "True". The calculation differs depending on whether it is "True" or "False".
  136. *@par Outputs:
  137. *var: A Tensor. Has the same type and format as input "var".
  138. *@par Third-party framework compatibility
  139. * Compatible with the TensorFlow operator SparseApplyAdagrad.
  140. */
  141. REG_OP(SparseApplyAdagrad)
  142. .INPUT(var, TensorType({DT_FLOAT}))
  143. .INPUT(accum, TensorType({DT_FLOAT}))
  144. .INPUT(lr, TensorType({DT_FLOAT}))
  145. .INPUT(grad, TensorType({DT_FLOAT}))
  146. .INPUT(indices, TensorType({DT_INT32}))
  147. .OUTPUT(var, TensorType({DT_FLOAT}))
  148. .ATTR(use_locking, Bool, false)
  149. .ATTR(update_slots, Bool, true)
  150. .OP_END_FACTORY_REG(SparseApplyAdagrad)
  151. /**
  152. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  153. *@par Inputs:
  154. * Four inputs, including:
  155. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  156. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  157. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  158. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  159. *@par Attributes:
  160. *@li lr: A required float, used for computation.
  161. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  162. *@li update_slots: An optional bool. Defaults to "True". The calculation differs depending on whether it is "True" or "False".
  163. *@par Outputs:
  164. *@li var: A Tensor. Has the same type and format as input "var".
  165. *@li accum: A Tensor. Has the same type and format as input "accum".
  166. *@par Third-party framework compatibility
  167. * Compatible with the TensorFlow operator SparseApplyAdagrad.
  168. */
  169. REG_OP(SparseApplyAdagradD)
  170. .INPUT(var, TensorType({DT_FLOAT}))
  171. .INPUT(accum, TensorType({DT_FLOAT}))
  172. .INPUT(grad, TensorType({DT_FLOAT}))
  173. .INPUT(indices, TensorType({DT_INT32}))
  174. .OUTPUT(var, TensorType({DT_FLOAT}))
  175. .OUTPUT(accum, TensorType({DT_FLOAT}))
  176. .REQUIRED_ATTR(lr, Float)
  177. .ATTR(use_locking, Bool, false)
  178. .ATTR(update_slots, Bool, true)
  179. .OP_END_FACTORY_REG(SparseApplyAdagradD)
  180. /**
  181. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  182. *@par Inputs:
  183. *Six inputs, including:
  184. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  185. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  186. *@li lr: An NCHW, NHWC, or ND Tensor of type float32.
  187. *@li epsilon: An NCHW, NHWC, or ND Tensor of type float32.
  188. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  189. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  190. *@par Attributes:
  191. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  192. *@li update_slots: An optional bool. Defaults to "True". If "False", the computation logic will be different.
  193. *@par Outputs:
  194. *var: A Tensor. Has the same type and format as input "var".
  195. *@par Third-party framework compatibility
  196. *Compatible with the TensorFlow operator SparseApplyAdagradV2.
  197. */
  198. REG_OP(SparseApplyAdagradV2)
  199. .INPUT(var, TensorType({DT_FLOAT}))
  200. .INPUT(accum, TensorType({DT_FLOAT}))
  201. .INPUT(lr, TensorType({DT_FLOAT}))
  202. .INPUT(epsilon, TensorType({DT_FLOAT}))
  203. .INPUT(grad, TensorType({DT_FLOAT}))
  204. .INPUT(indices, TensorType({DT_INT32}))
  205. .OUTPUT(var, TensorType({DT_FLOAT}))
  206. .ATTR(use_locking, Bool, false)
  207. .ATTR(update_slots, Bool, true)
  208. .OP_END_FACTORY_REG(SparseApplyAdagradV2)
  209. /**
  210. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  211. *@par Inputs:
  212. *Four inputs, including:
  213. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  214. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  215. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  216. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  217. *@par Attributes:
  218. *@li lr: A required float, used for computation.
  219. *@li epsilon: A required float, used for computation.
  220. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  221. *@li update_slots: An optional bool. Defaults to "True". If "False", the computation logic will be different.
  222. *@par Outputs:
  223. *@li var: A Tensor. Has the same type and format as input "var".
  224. *@li accum: A Tensor. Has the same type and format as input "accum".
  225. *@par Third-party framework compatibility
  226. *Compatible with the TensorFlow operator SparseApplyAdagradV2.
  227. */
  228. REG_OP(SparseApplyAdagradV2D)
  229. .INPUT(var, TensorType({DT_FLOAT}))
  230. .INPUT(accum, TensorType({DT_FLOAT}))
  231. .INPUT(grad, TensorType({DT_FLOAT}))
  232. .INPUT(indices, TensorType({DT_INT32}))
  233. .OUTPUT(var, TensorType({DT_FLOAT}))
  234. .OUTPUT(accum, TensorType({DT_FLOAT}))
  235. .REQUIRED_ATTR(lr, Float)
  236. .REQUIRED_ATTR(epsilon, Float)
  237. .ATTR(use_locking, Bool, false)
  238. .ATTR(update_slots, Bool, true)
  239. .OP_END_FACTORY_REG(SparseApplyAdagradV2D)
  240. /**
  241. *@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
  242. * want to use Nesterov momentum.
  243. * computing process: \n
  244. * accum = accum * momentum + grad\n
  245. * var -= lr * accum
  246. *
  247. *@attention Constraints:
  248. * the input tensors must have the same shape.
  249. *
  250. *@par Inputs:
  251. *@li var: A mutable tensor. Should be from a Variable().
  252. *@li accum: A mutable tensor. Has the same type as "var". Should be from a Variable().
  253. *@li lr: A scalar. Has the same type as "var".
  254. *@li grad: A tensor for the gradient. Has the same type as "var".
  255. *@li momentum: A scalar for the momentum. Has the same type as "var".
  256. *
  257. *@par Attributes:
  258. *@li use_nesterov: An optional bool. Defaults to "False".
  259. * If "True", the tensor passed to compute grad will be
  260. * var - lr * momentum * accum, so in the end, the var you get is actually
  261. * var - lr * momentum * accum.
  262. *
  263. *@li use_locking: An optional bool. Defaults to "False".
  264. * If "True", updating of the "var" and "accum" tensors is protected by a lock;
  265. * otherwise the behavior is undefined, but may exhibit less contention.
  266. *
  267. *@par Outputs:
  268. * var: A mutable tensor. Has the same type as input "var".
  269. *
  270. *@par Third-party framework compatibility
  271. *Compatible with the TensorFlow operator ApplyMomentum.
  272. *
  273. */
  274. REG_OP(ApplyMomentum)
  275. .INPUT(var, TensorType::NumberType())
  276. .INPUT(accum, TensorType::NumberType())
  277. .INPUT(lr, TensorType::NumberType())
  278. .INPUT(grad, TensorType::NumberType())
  279. .INPUT(momentum, TensorType::NumberType())
  280. .OUTPUT(var, TensorType::NumberType())
  281. .ATTR(use_nesterov, Bool, false)
  282. .ATTR(use_locking, Bool, false)
  283. .OP_END_FACTORY_REG(ApplyMomentum)
  284. REG_OP(ApplyMomentumCCE)  // Declares the same inputs, output and attributes as ApplyMomentum. NOTE(review): presumably the CCE-kernel variant of that op - confirm.
  285. .INPUT(var, TensorType::NumberType())
  286. .INPUT(accum, TensorType::NumberType())
  287. .INPUT(lr, TensorType::NumberType())
  288. .INPUT(grad, TensorType::NumberType())
  289. .INPUT(momentum, TensorType::NumberType())
  290. .OUTPUT(var, TensorType::NumberType())
  291. .ATTR(use_nesterov, Bool, false)
  292. .ATTR(use_locking, Bool, false)
  293. .OP_END_FACTORY_REG(ApplyMomentumCCE)
  294. /**
  295. *@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
  296. * want to use Nesterov momentum.
  297. * computing process: \n
  298. * accum = accum * momentum + grad\n
  299. * var -= lr * accum
  300. *
  301. *@attention Constraints:
  302. * the input tensors must have the same shape.
  303. *
  304. *@par Inputs:
  305. *@li var: A mutable tensor. Should be from a Variable().
  306. *@li accum: A mutable tensor. Has the same type as "var". Should be from a Variable().
  307. *@li lr: A scalar. Has the same type as "var".
  308. *@li grad: A tensor for the gradient. Has the same type as "var".
  309. *@li momentum: A scalar for the momentum. Has the same type as "var".
  310. *
  311. *@par Attributes:
  312. *@li use_nesterov: An optional bool. Defaults to "False".
  313. * If "True", the tensor passed to compute grad will be
  314. * var - lr * momentum * accum, so in the end, the var you get is actually
  315. * var - lr * momentum * accum.
  316. *
  317. *@li use_locking: An optional bool. Defaults to "False".
  318. * If "True", updating of the "var" and "accum" tensors is protected by a lock;
  319. * otherwise the behavior is undefined, but may exhibit less contention.
  320. *
  321. *@par Outputs:
  322. *@li var: A mutable tensor. Has the same type as input "var".
  323. *@li accum: A mutable tensor. Has the same type as input "accum".
  324. *@par Third-party framework compatibility
  325. *Compatible with the TensorFlow operator ApplyMomentum.
  326. *
  327. */
  328. REG_OP(ApplyMomentumD)
  329. .INPUT(var, TensorType::NumberType())
  330. .INPUT(accum, TensorType::NumberType())
  331. .INPUT(lr, TensorType::NumberType())
  332. .INPUT(grad, TensorType::NumberType())
  333. .INPUT(momentum, TensorType::NumberType())
  334. .OUTPUT(var, TensorType::NumberType())
  335. .OUTPUT(accum, TensorType::NumberType())
  336. .ATTR(use_nesterov, Bool, false)
  337. .ATTR(use_locking, Bool, false)
  338. .OP_END_FACTORY_REG(ApplyMomentumD)
  339. /**
  340. *@brief Updates "var" according to the momentum scheme.
  341. * accum = accum * momentum - grad * lr \n
  342. * if use_nesterov is True: \n
  343. * var += accum * momentum - grad * lr \n
  344. * else: \n
  345. * var += accum
  346. *
  347. *@par Inputs:
  348. *@li var: A mutable tensor. Must be one of the data types defined in
  349. * TensorType::NumberType(). Should be from a Variable().
  350. *@li accum: A mutable tensor. Has the same type as "var". Should be from a
  351. * Variable().
  352. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  353. * from a Variable().
  354. *@li grad: A tensor for the gradient. Has the same type as "var". Should be
  355. * from a Variable().
  356. *@li momentum: A scalar. Has the same type as "var".
  357. *
  358. *@par Attributes:
  359. *@li use_nesterov: An optional bool. Defaults to "False".
  360. * If "True", "var" is updated using Nesterov momentum.
  361. *@li use_locking: An optional bool. Defaults to "False".
  362. * If "True", updating of the "var" tensor is protected by a lock;
  363. * otherwise the behavior is undefined, but may exhibit less contention.
  364. *
  365. *@par Outputs:
  366. * var: A mutable tensor. Has the same type as input "var".
  367. *
  368. *@attention Constraints:
  369. * The input tensors must have the same shape.
  370. *
  371. *@par Third-party framework compatibility
  372. * Compatible with the TensorFlow operator ResourceApplyKerasMomentum.
  373. *
  374. */
  375. REG_OP(ApplyKerasMomentum)
  376. .INPUT(var, TensorType::NumberType())
  377. .INPUT(accum, TensorType::NumberType())
  378. .INPUT(lr, TensorType::NumberType())
  379. .INPUT(grad, TensorType::NumberType())
  380. .INPUT(momentum, TensorType::NumberType())
  381. .OUTPUT(var, TensorType::NumberType())
  382. .ATTR(use_locking, Bool, false)
  383. .ATTR(use_nesterov, Bool, false)
  384. .OP_END_FACTORY_REG(ApplyKerasMomentum)
  385. /**
  386. *@brief Updates "var" according to the momentum scheme.
  387. * accum = accum * momentum - grad * lr \n
  388. * if use_nesterov is True: \n
  389. * var += accum * momentum - grad * lr \n
  390. * else: \n
  391. * var += accum
  392. *
  393. *@par Inputs:
  394. *@li var: A mutable tensor. Must be one of the data types defined in
  395. * TensorType::NumberType(). Should be from a Variable().
  396. *@li accum: A mutable tensor. Has the same type as "var". Should be from a
  397. * Variable().
  398. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  399. * from a Variable().
  400. *@li grad: A tensor for the gradient. Has the same type as "var". Should be
  401. * from a Variable().
  402. *@li momentum: A scalar. Has the same type as "var". Should be from a
  403. * Variable().
  404. *
  405. *@par Attributes:
  406. *@li use_nesterov: An optional bool. Defaults to "False".
  407. * If "True", "var" is updated using Nesterov momentum.
  408. *@li use_locking: An optional bool. Defaults to "False".
  409. * If "True", updating of the "var" tensor is protected by a lock;
  410. * otherwise the behavior is undefined, but may exhibit less contention.
  411. *
  412. *@par Outputs:
  413. *@li var: A mutable tensor. Has the same type as input "var".
  414. *@li accum: A mutable tensor. Has the same type as input "accum".
  415. *
  416. *@attention Constraints:
  417. * The input tensors must have the same shape.
  418. *
  419. *@par Third-party framework compatibility
  420. * Compatible with the TensorFlow operator ResourceApplyKerasMomentum.
  421. *
  422. */
  423. REG_OP(ApplyKerasMomentumD)
  424. .INPUT(var, TensorType::NumberType())
  425. .INPUT(accum, TensorType::NumberType())
  426. .INPUT(lr, TensorType::NumberType())
  427. .INPUT(grad, TensorType::NumberType())
  428. .INPUT(momentum, TensorType::NumberType())
  429. .OUTPUT(var, TensorType::NumberType())
  430. .OUTPUT(accum, TensorType::NumberType())
  431. .ATTR(use_locking, Bool, false)
  432. .ATTR(use_nesterov, Bool, false)
  433. .OP_END_FACTORY_REG(ApplyKerasMomentumD)
  434. /**
  435. *@brief Updates "var" according to the Adam algorithm with the AMSGrad variant.
  436. * lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)
  437. * m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
  438. * v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
  439. * vhat_t := max{vhat_{t-1}, v_t}
  440. * variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon)
  441. *
  442. *@par Inputs:
  443. *@li var: A mutable tensor. Must be one of the data types defined in
  444. * TensorType::NumberType(). Should be from a Variable().
  445. *@li m: A mutable tensor. Has the same type as "var". Should be from a
  446. * Variable().
  447. *@li v: A mutable tensor. Has the same type as "var". Should be from a
  448. * Variable().
  449. *@li vhat: A mutable tensor. Has the same type as "var". Should be from a
  450. * Variable().
  451. *@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
  452. * Variable().
  453. *@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
  454. * Variable().
  455. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  456. * from a Variable().
  457. *@li grad: A tensor for the gradient. Has the same type as "var". Should be
  458. * from a Variable().
  459. *
  460. *@par Attributes:
  461. *@li beta1: A required float.
  462. *@li beta2: A required float.
  463. *@li epsilon: A required float.
  464. *@li use_locking: An optional bool. Defaults to "False".
  465. * If "True", updating of the "var" tensor is protected by a lock;
  466. * otherwise the behavior is undefined, but may exhibit less contention.
  467. *
  468. *@par Outputs:
  469. *@li var: A mutable tensor. Has the same type as input "var".
  470. *@li m: A mutable tensor. Has the same type as input "var".
  471. *@li v: A mutable tensor. Has the same type as input "var".
  472. *@li vhat: A mutable tensor. Has the same type as input "var".
  473. *
  474. *@attention Constraints:
  475. * The input tensors must have the same shape.
  476. *
  477. *@par Third-party framework compatibility
  478. * Compatible with the TensorFlow operator ResourceApplyAdamWithAmsgrad.
  479. *
  480. */
  481. REG_OP(ApplyAdamWithAmsgradD)
  482. .INPUT(var, TensorType::NumberType())
  483. .INPUT(m, TensorType::NumberType())
  484. .INPUT(v, TensorType::NumberType())
  485. .INPUT(vhat, TensorType::NumberType())
  486. .INPUT(beta1_power, TensorType::NumberType())
  487. .INPUT(beta2_power, TensorType::NumberType())
  488. .INPUT(lr, TensorType::NumberType())
  489. .INPUT(grad, TensorType::NumberType())
  490. .OUTPUT(var, TensorType::NumberType())
  491. .OUTPUT(m, TensorType::NumberType())
  492. .OUTPUT(v, TensorType::NumberType())
  493. .OUTPUT(vhat, TensorType::NumberType())
  494. .REQUIRED_ATTR(beta1, Float)
  495. .REQUIRED_ATTR(beta2, Float)
  496. .REQUIRED_ATTR(epsilon, Float)
  497. .ATTR(use_locking, Bool, false)
  498. .OP_END_FACTORY_REG(ApplyAdamWithAmsgradD)
  499. /**
  500. *@brief Updates "var" according to the Adam algorithm, using the AMSGrad
  501. * variant that keeps the running maximum "vhat" of the second moment.
  502. *
  503. * lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)
  504. * m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
  505. * v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
  506. * vhat_t := max{vhat_{t-1}, v_t}
  507. * variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon)
  508. *
  509. *@par Inputs:
  510. *Eleven inputs, including:
  511. *@li var: A mutable tensor. Must be one of the data types defined in
  512. * TensorType::NumberType(). Should be from a Variable().
  513. *@li m: A mutable tensor. Has the same type as "var". Should be from a
  514. * Variable().
  515. *@li v: A mutable tensor. Has the same type as "var". Should be from a
  516. * Variable().
  517. *@li vhat: A mutable tensor. Has the same type as "var". Should be from a
  518. * Variable().
  519. *@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
  520. * Variable().
  521. *@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
  522. * Variable().
  523. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  524. * from a Variable().
  525. *@li beta1: A scalar. Has the same type as "var".
  526. *@li beta2: A scalar. Has the same type as "var".
  527. *@li epsilon: A scalar. Has the same type as "var".
  528. *@li grad: A tensor for the gradient. Has the same type as "var". Should be
  529. * from a Variable().
  530. *
  531. *@par Attributes:
  532. * use_locking: An optional bool. Defaults to "False".
  533. * If "True", updating of the "var" tensor is protected by a lock;
  534. * otherwise the behavior is undefined, but may exhibit less contention.
  535. *
  536. *@par Outputs:
  537. * var: A mutable tensor. Has the same type as input "var".
  538. *
  539. *@attention Constraints:
  540. * The input tensors must have the same shape.
  541. *
  542. *@par Third-party framework compatibility
  543. * Compatible with the TensorFlow operator ResourceApplyAdamWithAmsgrad.
  544. *
  545. */
  546. REG_OP(ApplyAdamWithAmsgrad)
  547. .INPUT(var, TensorType::NumberType())
  548. .INPUT(m, TensorType::NumberType())
  549. .INPUT(v, TensorType::NumberType())
  550. .INPUT(vhat, TensorType::NumberType())
  551. .INPUT(beta1_power, TensorType::NumberType())
  552. .INPUT(beta2_power, TensorType::NumberType())
  553. .INPUT(lr, TensorType::NumberType())
  554. .INPUT(beta1, TensorType::NumberType())
  555. .INPUT(beta2, TensorType::NumberType())
  556. .INPUT(epsilon, TensorType::NumberType())
  557. .INPUT(grad, TensorType::NumberType())
  558. .OUTPUT(var, TensorType::NumberType())
  559. .ATTR(use_locking, Bool, false)
  560. .OP_END_FACTORY_REG(ApplyAdamWithAmsgrad)
  561. /**
  562. *@brief Updates "var" according to the PowerSign update.
  563. * t-1 means the previous period.
  564. * m_t <- beta * m_{t-1} + (1 - beta) * grad\n
  565. * update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
  566. * var <- var - lr * update
  567. *
  568. *@attention Constraints:
  569. * the input tensors must have the same shape.
  570. *
  571. *@par Inputs:
  572. *@li var: A mutable tensor. Should be from a Variable().
  573. *@li m: A mutable tensor. Has the same type as "var".
  574. * Should be from a Variable().
  575. *@li lr: A scalar. Has the same type as "var".
  576. *@li logbase: A scalar. Has the same type as "var".
  577. *@li sign_decay: A scalar. Has the same type as "var".
  578. *@li beta: A scalar. Has the same type as "var".
  579. *@li grad: A tensor for the gradient. Has the same type as "var".
  580. *
  581. *@par Attributes:
  582. * use_locking: An optional bool. Defaults to "False".
  583. * If "True", updating of the "var" and "m" tensors is protected
  584. * by a lock; otherwise the behavior is undefined, but may exhibit less
  585. * contention.
  586. *
  587. *@par Outputs:
  588. * var: A mutable tensor. Has the same type as input "var".
  589. *
  590. *@par Third-party framework compatibility
  591. *Compatible with the TensorFlow operator ApplyPowerSign.
  592. *
  593. */
  594. REG_OP(ApplyPowerSign)
  595. .INPUT(var, TensorType::NumberType())
  596. .INPUT(m, TensorType::NumberType())
  597. .INPUT(lr, TensorType::NumberType())
  598. .INPUT(logbase, TensorType::NumberType())
  599. .INPUT(sign_decay, TensorType::NumberType())
  600. .INPUT(beta, TensorType::NumberType())
  601. .INPUT(grad, TensorType::NumberType())
  602. .OUTPUT(var, TensorType::NumberType())
  603. .ATTR(use_locking, Bool, false)
  604. .OP_END_FACTORY_REG(ApplyPowerSign)
  605. /**
  606. *@brief Updates "var" according to the PowerSign update.
  607. * t-1 means the previous period.
  608. * m_t <- beta * m_{t-1} + (1 - beta) * grad\n
  609. * update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
  610. * var <- var - lr * update
  611. *
  612. *@attention Constraints:
  613. * the input tensors must have the same shape.
  614. *
  615. *@par Inputs:
  616. *@li var: A mutable tensor. Should be from a Variable().
  617. *@li m: A mutable tensor. Has the same type as "var".
  618. * Should be from a Variable().
  619. *@li lr: A scalar. Has the same type as "var".
  620. *@li logbase: A scalar. Has the same type as "var".
  621. *@li sign_decay: A scalar. Has the same type as "var".
  622. *@li beta: A scalar. Has the same type as "var".
  623. *@li grad: A tensor for the gradient. Has the same type as "var".
  624. *
  625. *@par Attributes:
  626. * use_locking: An optional bool. Defaults to "False".
  627. * If "True", updating of the "var" and "m" tensors is protected
  628. * by a lock; otherwise the behavior is undefined, but may exhibit less
  629. * contention.
  630. *
  631. *@par Outputs:
  632. *@li var: A mutable tensor. Has the same type as input "var".
  633. *@li m: A mutable tensor. Has the same type as input "var".
  634. *
  635. *@par Third-party framework compatibility
  636. *Compatible with the TensorFlow operator ApplyPowerSign.
  637. *
  638. */
  639. REG_OP(ApplyPowerSignD)
  640. .INPUT(var, TensorType::NumberType())
  641. .INPUT(m, TensorType::NumberType())
  642. .INPUT(lr, TensorType::NumberType())
  643. .INPUT(logbase, TensorType::NumberType())
  644. .INPUT(sign_decay, TensorType::NumberType())
  645. .INPUT(beta, TensorType::NumberType())
  646. .INPUT(grad, TensorType::NumberType())
  647. .OUTPUT(var, TensorType::NumberType())
  648. .OUTPUT(m, TensorType::NumberType())
  649. .ATTR(use_locking, Bool, false)
  650. .OP_END_FACTORY_REG(ApplyPowerSignD)
  651. /**
  652. *@brief Updates "var" using the FOBOS algorithm with a fixed learning rate.\n
  653. * prox_v = var - alpha * delta\n
  654. * var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
  655. *
  656. *@attention Constraints:\n
  657. * the input tensors must have the same shape.
  658. *
  659. *@par Inputs:
  660. *@li var: A mutable tensor. Should be from a Variable().
  661. *@li alpha: A scalar. Has the same type as "var".
  662. *@li l1: A scalar. Has the same type as "var".
  663. *@li l2: A scalar. Has the same type as "var".
  664. *@li delta: A tensor. Has the same type as "var". The change.
  665. *
  666. *@par Attributes:
  667. * use_locking: An optional bool. Defaults to "False".
  668. * If "True", updating of the "var" tensor is protected
  669. * by a lock; otherwise the behavior is undefined, but may exhibit less
  670. * contention.
  671. *
  672. *@par Outputs:
  673. * var: A mutable tensor. Has the same type as input "var".
  674. *
  675. *@par Third-party framework compatibility
  676. *Compatible with the TensorFlow operator ApplyProximalGradientDescent.
  677. *
  678. */
  679. REG_OP(ApplyProximalGradientDescent)
  680. .INPUT(var, TensorType::NumberType())
  681. .INPUT(alpha, TensorType::NumberType())
  682. .INPUT(l1, TensorType::NumberType())
  683. .INPUT(l2, TensorType::NumberType())
  684. .INPUT(delta, TensorType::NumberType())
  685. .OUTPUT(var, TensorType::NumberType())
  686. .ATTR(use_locking, Bool, false)
  687. .OP_END_FACTORY_REG(ApplyProximalGradientDescent)
  688. /**
  689. *@brief Updates "var" according to the AddSign update.
  690. *@par Inputs:
  691. *Seven inputs, including:
  692. * @li var: A mutable Tensor of type TensorType::NumberType().
  693. * Should be a Variable Tensor.
  694. * @li m: A mutable Tensor of the same type as "var".
  695. * Should be a Variable Tensor.
  696. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  697. * @li alpha: A Tensor of the same type as "var". Must be a scalar.
  698. * @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
  699. * @li beta: A Tensor of the same type as "var". Must be a scalar.
  700. * @li grad: A Tensor of the same type as "var", for the gradient.
  701. *@par Attributes:
  702. *use_locking: An optional bool. Defaults to "False".
  703. * If "True", updating of the "var" and "m" tensors will be
  704. * protected by a lock; otherwise the behavior is undefined,
  705. * but may exhibit less contention.
  706. *@par Outputs:
  707. *var: A mutable Tensor. Has the same type as input "var".
  708. *@par Third-party framework compatibility
  709. * Compatible with the TensorFlow operator ApplyAddSign.
  710. */
  711. REG_OP(ApplyAddSign)
  712. .INPUT(var, TensorType::NumberType())
  713. .INPUT(m, TensorType::NumberType())
  714. .INPUT(lr, TensorType::NumberType())
  715. .INPUT(alpha, TensorType::NumberType())
  716. .INPUT(sign_decay, TensorType::NumberType())
  717. .INPUT(beta, TensorType::NumberType())
  718. .INPUT(grad, TensorType::NumberType())
  719. .OUTPUT(var, TensorType::NumberType())
  720. .ATTR(use_locking, Bool, false)
  721. .OP_END_FACTORY_REG(ApplyAddSign)
  722. /**
  723. *@brief Updates "var" according to the AddSign update.
  724. *@par Inputs:
  725. *Seven inputs, including:
  726. * @li var: A mutable Tensor of type TensorType::NumberType().
  727. * Should be a Variable Tensor.
  728. * @li m: A mutable Tensor of the same type as "var".
  729. * Should be a Variable Tensor.
  730. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  731. * @li alpha: A Tensor of the same type as "var". Must be a scalar.
  732. * @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
  733. * @li beta: A Tensor of the same type as "var". Must be a scalar.
  734. * @li grad: A Tensor of the same type as "var", for the gradient.
  735. *@par Attributes:
  736. *use_locking: An optional bool. Defaults to "False".
  737. * If "True", updating of the "var" and "m" tensors will be
  738. * protected by a lock; otherwise the behavior is undefined,
  739. * but may exhibit less contention.
  740. *@par Outputs:
  741. *@li var: A mutable Tensor. Has the same type as input "var".
  742. *@li m: A mutable Tensor. Has the same type as input "m".
  743. *@par Third-party framework compatibility
  744. * Compatible with the TensorFlow operator ApplyAddSign.
  745. */
  746. REG_OP(ApplyAddSignD)
  747. .INPUT(var, TensorType::NumberType())
  748. .INPUT(m, TensorType::NumberType())
  749. .INPUT(lr, TensorType::NumberType())
  750. .INPUT(alpha, TensorType::NumberType())
  751. .INPUT(sign_decay, TensorType::NumberType())
  752. .INPUT(beta, TensorType::NumberType())
  753. .INPUT(grad, TensorType::NumberType())
  754. .OUTPUT(var, TensorType::NumberType())
  755. .OUTPUT(m, TensorType::NumberType())
  756. .ATTR(use_locking, Bool, false)
  757. .OP_END_FACTORY_REG(ApplyAddSignD)
  758. /**
  759. *@brief Updates "var" according to the centered RMSProp algorithm.
  760. * The centered RMSProp algorithm uses an estimate of the centered second moment
  761. * (i.e., the variance) for normalization, as opposed to regular RMSProp, which
  762. * uses the (uncentered) second moment. This often helps with training, but is
  763. * slightly more expensive in terms of computation and memory.
  764. *
  765. * t-1 means the previous period.
  766. * mg <- rho * mg{t-1} + (1-rho) * grad\n
  767. * ms <- rho * ms{t-1} + (1-rho) * grad * grad\n
  768. * mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
  769. * var <- var - mom\n
  770. *
  771. *@attention Constraints:
  772. *@li in dense implementation of this algorithm, mg, ms, and mom will
  773. * update even if the grad is zero, but in this sparse implementation, mg, ms,
  774. * and mom will not update in iterations during which the grad is zero.
  775. *@li the input tensors must have the same shape.
  776. *
  777. *@par Inputs:
  778. *@li var: A mutable tensor. Should be from a Variable().
  779. *@li mg: A mutable tensor. Has the same type as "var".
  780. * Should be from a Variable().
  781. *@li ms: A mutable tensor. Has the same type as "var".
  782. * Should be from a Variable().
  783. *@li mom: A mutable tensor. Has the same type as "var".
  784. * Should be from a Variable().
  785. *@li lr: A scalar. Has the same type as "var".
  786. *@li rho: A scalar. Has the same type as "var".
  787. *@li momentum: A tensor. Has the same type as "var".
  788. *@li epsilon: A scalar. Has the same type as "var".
  789. *@li grad: A tensor for the gradient. Has the same type as "var".
  790. *
  791. *@par Attributes:
  792. * use_locking: An optional bool. Defaults to "False".
  793. * If "True", updating of the "var", "mg", "ms", and "mom" tensors is protected
  794. * by a lock; otherwise the behavior is undefined, but may exhibit less
  795. * contention.
  796. *
  797. *@par Outputs:
  798. * var: A mutable tensor. Has the same type as input "var".
  799. *
  800. *@par Third-party framework compatibility
  801. *Compatible with the TensorFlow operator ApplyCenteredRMSProp.
  802. *
  803. */
  804. REG_OP(ApplyCenteredRMSProp)
  805. .INPUT(var, TensorType::NumberType())
  806. .INPUT(mg, TensorType::NumberType())
  807. .INPUT(ms, TensorType::NumberType())
  808. .INPUT(mom, TensorType::NumberType())
  809. .INPUT(lr, TensorType::NumberType())
  810. .INPUT(rho, TensorType::NumberType())
  811. .INPUT(momentum, TensorType::NumberType())
  812. .INPUT(epsilon, TensorType::NumberType())
  813. .INPUT(grad, TensorType::NumberType())
  814. .OUTPUT(var, TensorType::NumberType())
  815. .ATTR(use_locking, Bool, false)
  816. .OP_END_FACTORY_REG(ApplyCenteredRMSProp)
  817. /**
  818. *@brief Updates "var" according to the centered RMSProp algorithm.
  819. * The centered RMSProp algorithm uses an estimate of the centered second moment
  820. * (i.e., the variance) for normalization, as opposed to regular RMSProp, which
  821. * uses the (uncentered) second moment. This often helps with training, but is
  822. * slightly more expensive in terms of computation and memory.
  823. *
  824. * t-1 means the previous period.
  825. * mg <- rho * mg{t-1} + (1-rho) * grad\n
  826. * ms <- rho * ms{t-1} + (1-rho) * grad * grad\n
  827. * mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
  828. * var <- var - mom\n
  829. *
  830. *@attention Constraints:
  831. *@li in dense implementation of this algorithm, mg, ms, and mom will
  832. * update even if the grad is zero, but in this sparse implementation, mg, ms,
  833. * and mom will not update in iterations during which the grad is zero.
  834. *@li the input tensors must have the same shape.
  835. *
  836. *@par Inputs:
  837. *@li var: A mutable tensor. Should be from a Variable().
  838. *@li mg: A mutable tensor. Has the same type as "var".
  839. * Should be from a Variable().
  840. *@li ms: A mutable tensor. Has the same type as "var".
  841. * Should be from a Variable().
  842. *@li mom: A mutable tensor. Has the same type as "var".
  843. * Should be from a Variable().
  844. *@li lr: A scalar. Has the same type as "var".
  845. *@li rho: A scalar. Has the same type as "var".
  846. *@li momentum: A tensor. Has the same type as "var".
  847. *@li epsilon: A scalar. Has the same type as "var".
  848. *@li grad: A tensor for the gradient. Has the same type as "var".
  849. *
  850. *@par Attributes:
  851. * use_locking: An optional bool. Defaults to "False".
  852. * If "True", updating of the "var", "mg", "ms", and "mom" tensors is protected
  853. * by a lock; otherwise the behavior is undefined, but may exhibit less
  854. * contention.
  855. *
  856. *@par Outputs:
  857. *@li var: A mutable Tensor. Has the same type as input "var".
  858. *@li mg: A mutable Tensor. Has the same type as input "mg".
  859. *@li ms: A mutable Tensor. Has the same type as input "ms".
  860. *@li mom: A mutable Tensor. Has the same type as input "mom".
  861. *@par Third-party framework compatibility
  862. *Compatible with the TensorFlow operator ApplyCenteredRMSProp.
  863. *
  864. */
  865. REG_OP(ApplyCenteredRMSPropD)
  866. .INPUT(var, TensorType::NumberType())
  867. .INPUT(mg, TensorType::NumberType())
  868. .INPUT(ms, TensorType::NumberType())
  869. .INPUT(mom, TensorType::NumberType())
  870. .INPUT(lr, TensorType::NumberType())
  871. .INPUT(rho, TensorType::NumberType())
  872. .INPUT(momentum, TensorType::NumberType())
  873. .INPUT(epsilon, TensorType::NumberType())
  874. .INPUT(grad, TensorType::NumberType())
  875. .OUTPUT(var, TensorType::NumberType())
  876. .OUTPUT(mg, TensorType::NumberType())
  877. .OUTPUT(ms, TensorType::NumberType())
  878. .OUTPUT(mom, TensorType::NumberType())
  879. .ATTR(use_locking, Bool, false)
  880. .OP_END_FACTORY_REG(ApplyCenteredRMSPropD)
  881. /**
  882. *@brief Updates "var" by subtracting "alpha" * "delta" from it.
  883. * var -= delta * alpha
  884. *
  885. *@attention Constraints:
  886. * the input tensors must have the same shape.
  887. *
  888. *@par Inputs:
  889. *@li var: A mutable tensor. Should be from a Variable().
  890. *@li alpha: A scalar. Has the same type as "var".
  891. *@li delta: A tensor for the change. Has the same type as "var".
  892. *
  893. *@par Attributes:
  894. * use_locking: An optional bool. Defaults to "False".
  895. * If "True", updating of the "var" tensor is protected
  896. * by a lock; otherwise the behavior is undefined, but may exhibit less
  897. * contention.
  898. *
  899. *@par Outputs:
  900. * var: A mutable tensor. Has the same type as input "var".
  901. *
  902. *@par Third-party framework compatibility
  903. *Compatible with the TensorFlow operator ApplyGradientDescent.
  904. *
  905. */
  906. REG_OP(ApplyGradientDescent)
  907. .INPUT(var, TensorType::NumberType())
  908. .INPUT(alpha, TensorType::NumberType())
  909. .INPUT(delta, TensorType::NumberType())
  910. .OUTPUT(var, TensorType::NumberType())
  911. .ATTR(use_locking, Bool, false)
  912. .OP_END_FACTORY_REG(ApplyGradientDescent)
  913. /**
  914. *@brief Updates "var" according to the adagrad scheme.
  915. * accum += grad * grad\n
  916. * var -= lr * grad * (1 / sqrt(accum))
  917. *
  918. *@attention Constraints:
  919. * the input tensors must have the same shape.
  920. *
  921. *@par Inputs:
  922. *@li var: A mutable tensor. Should be from a Variable().
  923. *@li accum: A mutable tensor. Has the same type as "var".
  924. * Should be from a Variable().
  925. *@li lr: A scalar. Has the same type as "var".
  926. *@li grad: A tensor for the gradient. Has the same type as "var".
  927. *
  928. *@par Attributes:
  929. *@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
  930. *@li use_locking: An optional bool. Defaults to "False".
  931. * If "True", updating of the "var" and "accum" tensors is protected
  932. * by a lock; otherwise the behavior is undefined, but may exhibit less
  933. * contention.
  934. *
  935. *@par Outputs:
  936. * var: A mutable tensor. Has the same type as input "var".
  937. *
  938. *@par Third-party framework compatibility
  939. *Compatible with the TensorFlow operator ApplyAdagrad.
  940. *
  941. */
  942. REG_OP(ApplyAdagrad)
  943. .INPUT(var, TensorType::NumberType())
  944. .INPUT(accum, TensorType::NumberType())
  945. .INPUT(lr, TensorType::NumberType())
  946. .INPUT(grad, TensorType::NumberType())
  947. .OUTPUT(var, TensorType::NumberType())
  948. .ATTR(update_slots, Bool, true)
  949. .ATTR(use_locking, Bool, false)
  950. .OP_END_FACTORY_REG(ApplyAdagrad)
  951. /**
  952. *@brief Updates "var" according to the adagrad scheme.
  953. * accum += grad * grad\n
  954. * var -= lr * grad * (1 / sqrt(accum))
  955. *
  956. *@attention Constraints:
  957. * the input tensors must have the same shape.
  958. *
  959. *@par Inputs:
  960. *@li var: A mutable tensor. Should be from a Variable().
  961. *@li accum: A mutable tensor. Has the same type as "var".
  962. * Should be from a Variable().
  963. *@li lr: A scalar. Has the same type as "var".
  964. *@li grad: A tensor for the gradient. Has the same type as "var".
  965. *
  966. *@par Attributes:
  967. *@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
  968. *@li use_locking: An optional bool. Defaults to "False".
  969. * If "True", updating of the "var" and "accum" tensors is protected
  970. * by a lock; otherwise the behavior is undefined, but may exhibit less
  971. * contention.
  972. *
  973. *@par Outputs:
  974. *@li var: A mutable tensor. Has the same type as input "var".
  975. *@li accum: A mutable tensor. Has the same type as input "var".
  976. *
  977. *@par Third-party framework compatibility
  978. *Compatible with the TensorFlow operator ApplyAdagrad.
  979. *
  980. */
  981. REG_OP(ApplyAdagradD)
  982. .INPUT(var, TensorType::NumberType())
  983. .INPUT(accum, TensorType::NumberType())
  984. .INPUT(lr, TensorType::NumberType())
  985. .INPUT(grad, TensorType::NumberType())
  986. .OUTPUT(var, TensorType::NumberType())
  987. .OUTPUT(accum, TensorType::NumberType())
  988. .ATTR(update_slots, Bool, true)
  989. .ATTR(use_locking, Bool, false)
  990. .OP_END_FACTORY_REG(ApplyAdagradD)
  991. /**
  992. * @brief Updates "var" according to the adagradv2 scheme.
  993. * accum += grad * grad \n
  994. * var -= lr * grad * (1 / (sqrt(accum) + epsilon))
  995. *
  996. * @par Inputs:
  997. * @li var: A mutable tensor. Must be one of the data types defined in
  998. * TensorType::NumberType(). Should be from a Variable().
  999. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  1000. * Variable().
  1001. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  1002. * from a Variable().
  1003. * @li epsilon: A scalar. Has the same type as "var".
  1004. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  1005. * from a Variable().
  1006. *
  1007. * @par Attributes:
  1008. * @li update_slots: An optional bool. Defaults to "True".
  1009. * If "True", "accum" will be updated.
  1010. * @li use_locking: An optional bool. Defaults to "False".
  1011. * If "True", updating of the "var" tensor is protected by a lock;
  1012. * otherwise the behavior is undefined, but may exhibit less contention.
  1013. *
  1014. * @par Outputs:
  1015. * var: A mutable tensor. Has the same type as input "var".
  1016. *
  1017. * @attention Constraints:
  1018. * The input tensors must have the same shape.
  1019. *
  1020. * @par Third-party framework compatibility
  1021. * Compatible with the TensorFlow operator ApplyAdagradV2.
  1022. *
  1023. */
  1024. REG_OP(ApplyAdagradV2)
  1025. .INPUT(var, TensorType::NumberType())
  1026. .INPUT(accum, TensorType::NumberType())
  1027. .INPUT(lr, TensorType::NumberType())
  1028. .INPUT(epsilon, TensorType::NumberType())
  1029. .INPUT(grad, TensorType::NumberType())
  1030. .OUTPUT(var, TensorType::NumberType())
  1031. .ATTR(update_slots, Bool, true)
  1032. .ATTR(use_locking, Bool, false)
  1033. .OP_END_FACTORY_REG(ApplyAdagradV2)
  1034. /**
  1035. * @brief Updates "var" according to the adagradv2 scheme.
  1036. * accum += grad * grad \n
  1037. * var -= lr * grad * (1 / (sqrt(accum) + epsilon))
  1038. *
  1039. * @par Inputs:
  1040. * @li var: A mutable tensor. Must be one of the data types defined in
  1041. * TensorType::NumberType(). Should be from a Variable().
  1042. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  1043. * Variable().
  1044. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  1045. * from a Variable().
  1046. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  1047. * from a Variable().
  1048. *
  1049. * @par Attributes:
  1050. * @li epsilon: A required float scalar.
  1051. * @li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
  1052. * @li use_locking: An optional bool. Defaults to "False".
  1053. * If "True", updating of the "var" tensor is protected by a lock;
  1054. * otherwise the behavior is undefined, but may exhibit less contention.
  1055. *
  1056. * @par Outputs:
  1057. * @li var: A mutable tensor. Has the same type as input "var".
  1058. * @li accum: A mutable tensor. Has the same type as input "var".
  1059. *
  1060. * @attention Constraints:
  1061. * The input tensors must have the same shape.
  1062. *
  1063. * @par Third-party framework compatibility
  1064. * Compatible with the TensorFlow operator ApplyAdagradV2.
  1065. *
  1066. */
  1067. REG_OP(ApplyAdagradV2D)
  1068. .INPUT(var, TensorType::NumberType())
  1069. .INPUT(accum, TensorType::NumberType())
  1070. .INPUT(lr, TensorType::NumberType())
  1071. .INPUT(grad, TensorType::NumberType())
  1072. .OUTPUT(var, TensorType::NumberType())
  1073. .OUTPUT(accum, TensorType::NumberType())
  1074. .REQUIRED_ATTR(epsilon, Float)
  1075. .ATTR(update_slots, Bool, true)
  1076. .ATTR(use_locking, Bool, false)
  1077. .OP_END_FACTORY_REG(ApplyAdagradV2D)
  1078. /**
  1079. *@brief Updates "var" according to the proximal adagrad scheme.
  1080. *@par Inputs:
  1081. *Eight inputs, including:
  1082. * @li var: A mutable Tensor. Must be one of the following types:
  1083. * TensorType::NumberType(). Should be a Variable Tensor.
  1084. * @li gradient_accumulator: A mutable Tensor. Must have the same
  1085. * type as "var". Should be a Variable Tensor.
  1086. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  1087. * Should be a Variable Tensor.
  1088. * @li grad: A Tensor of the same type as "var", for the gradient.
  1089. * @li lr: A Tensor of the same type as "var".
  1090. * Scaling factor. Must be a scalar.
  1091. * @li l1: A Tensor of the same type as "var".
  1092. * L1 regularization. Must be a scalar.
  1093. * @li l2: A Tensor of the same type as "var".
  1094. * L2 regularization. Must be a scalar.
  1095. * @li global_step: A Tensor of type int32 or int64.
  1096. * Training step number. Must be a scalar.
  1097. *@par Attributes:
  1098. *use_locking: An optional bool. Defaults to "False".
  1099. * If "True", updating of the var and accum tensors will be
  1100. * protected by a lock; otherwise the behavior is undefined,
  1101. * but may exhibit less contention.
  1102. *@par Outputs:
  1103. *var: A mutable Tensor. Has the same type as input "var".
  1104. *@par Third-party framework compatibility
  1105. *Compatible with the TensorFlow operator ApplyAdagradDA.
  1106. */
  1107. REG_OP(ApplyAdagradDA)
  1108. .INPUT(var, TensorType::NumberType())
  1109. .INPUT(gradient_accumulator, TensorType::NumberType())
  1110. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  1111. .INPUT(grad, TensorType::NumberType())
  1112. .INPUT(lr, TensorType::NumberType())
  1113. .INPUT(l1, TensorType::NumberType())
  1114. .INPUT(l2, TensorType::NumberType())
  1115. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  1116. .OUTPUT(var, TensorType::NumberType())
  1117. .ATTR(use_locking, Bool, false)
  1118. .OP_END_FACTORY_REG(ApplyAdagradDA)
  1119. /**
  1120. *@brief Updates "var" according to the proximal adagrad scheme.
  1121. *@par Inputs:
  1122. *Eight inputs, including:
  1123. * @li var: A mutable Tensor. Must be one of the following types:
  1124. * TensorType::NumberType(). Should be a Variable Tensor.
  1125. * @li gradient_accumulator: A mutable Tensor. Must have the same
  1126. * type as "var". Should be a Variable Tensor.
  1127. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  1128. * Should be a Variable Tensor.
  1129. * @li grad: A Tensor of the same type as "var", for the gradient.
  1130. * @li lr: A Tensor of the same type as "var".
  1131. * Scaling factor. Must be a scalar.
  1132. * @li l1: A Tensor of the same type as "var".
  1133. * L1 regularization. Must be a scalar.
  1134. * @li l2: A Tensor of the same type as "var".
  1135. * L2 regularization. Must be a scalar.
  1136. * @li global_step: A Tensor of type int32 or int64.
  1137. * Training step number. Must be a scalar.
  1138. *@par Attributes:
  1139. *use_locking: An optional bool. Defaults to "False".
  1140. * If "True", updating of the var and accum tensors will be
  1141. * protected by a lock; otherwise the behavior is undefined,
  1142. * but may exhibit less contention.
  1143. *@par Outputs:
  1144. *@li var: A mutable Tensor. Has the same type as input "var".
  1145. *@li gradient_accumulator: A mutable Tensor. Has the same type as input "var".
  1146. *@li gradient_squared_accumulator: A mutable Tensor. Has the same type as input "var".
  1147. *@par Third-party framework compatibility
  1148. *Compatible with the TensorFlow operator ApplyAdagradDA.
  1149. */
  1150. REG_OP(ApplyAdagradDAD)
  1151. .INPUT(var, TensorType::NumberType())
  1152. .INPUT(gradient_accumulator, TensorType::NumberType())
  1153. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  1154. .INPUT(grad, TensorType::NumberType())
  1155. .INPUT(lr, TensorType::NumberType())
  1156. .INPUT(l1, TensorType::NumberType())
  1157. .INPUT(l2, TensorType::NumberType())
  1158. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  1159. .OUTPUT(var, TensorType::NumberType())
  1160. .OUTPUT(gradient_accumulator, TensorType::NumberType())
  1161. .OUTPUT(gradient_squared_accumulator, TensorType::NumberType())
  1162. .ATTR(use_locking, Bool, false)
  1163. .OP_END_FACTORY_REG(ApplyAdagradDAD)
  1164. /**
  1165. *@brief Returns the dimension index in the destination data format given the one in
  1166. * the source data format.
  1167. *
  1168. *@par Inputs:
  1169. * x: A tensor of type int32 or int64.
  1170. * Each element is a dimension index in the source data format.
  1171. * Must be in the range [-4, 4).
  1172. *
  1173. *@par Attributes:
  1174. *@li src_format: An optional string. Defaults to "NHWC".
  1175. * Source data format. Must be of length 4.
  1176. *@li dst_format: An optional string. Defaults to "NCHW".
  1177. * Destination data format. Must be of length 4.
  1178. *
  1179. *@par Outputs:
  1180. * y: A tensor. Has the same type as "x". Must be in the range [0, 4).
  1181. *
  1182. *@par Third-party framework compatibility
  1183. *Compatible with the TensorFlow operator DataFormatDimMap.
  1184. *
  1185. */
  1186. REG_OP(DataFormatDimMap)
  1187. .INPUT(x, TensorType::IndexNumberType())
  1188. .ATTR(src_format, String, "NHWC")
  1189. .ATTR(dst_format, String, "NCHW")
  1190. .OUTPUT(y, TensorType::IndexNumberType())
  1191. .OP_END_FACTORY_REG(DataFormatDimMap)
  1192. /**
  1193. * @brief Implements stochastic gradient descent (optionally with momentum).\n
  1194. * Nesterov momentum is based on the formula from
  1195. * On the importance of initialization and momentum in deep learning.\n
  1196. * @par Inputs:
  1197. * @li parameters: A mutable tensor of type float16 or float32.\n
  1198. * Specifies the iterable of parameters to optimize or dicts defining parameter
  1199. * groups.
  1200. * @li gradient: A tensor of type float16 or float32.\n
  1201. * Specifies the gradient of training step.
  1202. * @li learning_rate: A tensor of type float16 or float32.\n
  1203. * Specifies the learing_rate of training step.
  1204. * @li accum: A tensor of type float16 or float32.
  1205. * Specifies the velocity of training step.
  1206. * @li momentum: A tensor of type float16 or float32.
  1207. * Specifies the momentum factor.
  1208. * @li stat: A tensor of type float16 or float32.
  1209. * Specifies the status representing the first step or not.
  1210. * @par Attributes:
  1211. * @li dampening: An optional float, specifying the dampening for momentum.
  1212. * Defaults to "0.0".
  1213. * @li weight_decay: An optional float, specifying the L2 penalty. Defaults to
  1214. * "0.0".
  1215. * @li nesterov: An optional bool, specifying whether to enable Nesterov
  1216. * momentum. Defaults to "False".
  1217. * @par Outputs:
  1218. * parameters: A mutable tensor same as input "parameters".
  1219. * @see ApplyMomentum()
  1220. * @par Third-party framework compatibility
  1221. * @li Compatible with the PyTorch operator SGD.
  1222. */
  1223. REG_OP(SGD)
  1224. .INPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  1225. .INPUT(gradient, TensorType(DT_FLOAT, DT_FLOAT16))
  1226. .INPUT(learning_rate, TensorType(DT_FLOAT, DT_FLOAT16))
  1227. .INPUT(accum, TensorType(DT_FLOAT, DT_FLOAT16))
  1228. .INPUT(momentum, TensorType(DT_FLOAT, DT_FLOAT16))
  1229. .INPUT(stat, TensorType(DT_FLOAT, DT_FLOAT16))
  1230. .OUTPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  1231. .ATTR(dampening, Float, 0.0)
  1232. .ATTR(weight_decay, Float, 0.0)
  1233. .ATTR(nesterov, Bool, false)
  1234. .OP_END_FACTORY_REG(SGD)
  1235. /**
  1236. * @brief Updates "var" according to the RMSProp algorithm.\n
  1237. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  1238. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  1239. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  1240. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  1241. * var <- var - mom\n
  1242. *
  1243. * @par Inputs:
  1244. * @li var: A mutable tensor. Must be one of the data types defined in
  1245. * TensorType::NumberType(). Should be from a Variable().
  1246. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  1247. * Variable().
  1248. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  1249. * Variable().
  1250. * @li lr: A scalar. Must have the same type as "var".
  1251. * @li rho: A scalar. Must have the same type as "var".
  1252. * @li momentum: A scalar. Must have the same type as "var".
  1253. * @li epsilon: A scalar. Must have the same type as "var".
  1254. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  1255. *
  1256. * @par Attributes:
  1257. * use_locking: An optional "bool". Defaults to "False". If "True", updating of
  1258. * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the
  1259. * behavior is undefined, but may exhibit less contention.
  1260. *
  1261. * @par Outputs:
  1262. * var: A mutable tensor. Has the same type as input "var". It is the only declared output.
  1263. *
  1264. * @attention Constraints:
  1265. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will
  1266. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"
  1267. * will not update in iterations during which "grad" is 0.
  1268. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  1269. *
  1270. * @par Third-party framework compatibility
  1271. * @li Compatible with the TensorFlow operator ApplyRMSProp.
  1272. */
  1273. REG_OP(ApplyRMSProp)
  1274. .INPUT(var, TensorType::NumberType())
  1275. .INPUT(ms, TensorType::NumberType())
  1276. .INPUT(mom, TensorType::NumberType())
  1277. .INPUT(lr, TensorType::NumberType())
  1278. .INPUT(rho, TensorType::NumberType())
  1279. .INPUT(momentum, TensorType::NumberType())
  1280. .INPUT(epsilon, TensorType::NumberType())
  1281. .INPUT(grad, TensorType::NumberType())
  1282. .OUTPUT(var, TensorType::NumberType())
  1283. .ATTR(use_locking, Bool, false)
  1284. .OP_END_FACTORY_REG(ApplyRMSProp)
  1285. /**
  1286. * @brief Updates "var" according to the RMSProp algorithm, a const input will be
  1287. * considered as an attribute.\n
  1288. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  1289. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  1290. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  1291. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  1292. * var <- var - mom
  1293. *
  1294. * @par Inputs:
  1295. * @li var: A mutable tensor. Must be one of the data types defined in
  1296. * TensorType::NumberType(). Should be from a Variable().
  1297. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  1298. * Variable().
  1299. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  1300. * Variable().
  1301. * @li lr: A scalar. Must have the same type as "var".
  1302. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  1303. *
  1304. * @par Attributes:
  1305. * @li use_locking: An optional "bool". Defaults to "False". If "True", updating
  1306. * of the "var", "ms", and "mom" tensors will be protected by a lock;
  1307. * otherwise the behavior is undefined, but may exhibit less contention.
  1308. * @li rho: A required float.
  1309. * @li momentum: A required float.
  1310. * @li epsilon: A required float.
  1311. *
  1312. * @par Outputs:
  1313. * var, ms, mom: Three mutable tensors. Each must have the same type as input "var".
  1314. *
  1315. * @attention Constraints:
  1316. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will
  1317. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"
  1318. * will not update in iterations during which "grad" is 0.
  1319. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  1320. *
  1321. * @par Third-party framework compatibility
  1322. * @li Compatible with the TensorFlow operator ApplyRMSProp.
  1323. */
  1324. REG_OP(ApplyRMSPropD)
  1325. .INPUT(var, TensorType::NumberType())
  1326. .INPUT(ms, TensorType::NumberType())
  1327. .INPUT(mom, TensorType::NumberType())
  1328. .INPUT(lr, TensorType::NumberType())
  1329. .INPUT(grad, TensorType::NumberType())
  1330. .OUTPUT(var, TensorType::NumberType())
  1331. .OUTPUT(ms, TensorType::NumberType())
  1332. .OUTPUT(mom, TensorType::NumberType())
  1333. .REQUIRED_ATTR(rho, Float)
  1334. .REQUIRED_ATTR(momentum, Float)
  1335. .REQUIRED_ATTR(epsilon, Float)
  1336. .ATTR(use_locking, Bool, false)
  1337. .OP_END_FACTORY_REG(ApplyRMSPropD)
  1338. /**
  1339. *@brief Updates "var" and "accum" according to FOBOS with Adagrad learning rate.
  1340. *@par Inputs:
  1341. *Six inputs, including:
  1342. * @li var: A mutable Tensor of type TensorType::NumberType().
  1343. * Should be from a Variable().
  1344. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  1345. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1346. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1347. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1348. * @li grad: A Tensor of the same type as "var", for the gradient.
  1349. *@par Attributes:
  1350. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1351. *@par Outputs:
  1352. *var: A mutable tensor. Must have the same type as input "var".
  1353. *@par Third-party framework compatibility
  1354. *Compatible with the TensorFlow operator ApplyProximalAdagrad.
  1355. */
  1356. REG_OP(ApplyProximalAdagrad)
  1357. .INPUT(var, TensorType::NumberType())
  1358. .INPUT(accum, TensorType::NumberType())
  1359. .INPUT(lr, TensorType::NumberType())
  1360. .INPUT(l1, TensorType::NumberType())
  1361. .INPUT(l2, TensorType::NumberType())
  1362. .INPUT(grad, TensorType::NumberType())
  1363. .OUTPUT(var, TensorType::NumberType())
  1364. .ATTR(use_locking, Bool, false)
  1365. .OP_END_FACTORY_REG(ApplyProximalAdagrad)
  1366. /**
  1367. *@brief Updates "var" and "accum" according to FOBOS with Adagrad learning rate.
  1368. *@par Inputs:
  1369. *Six inputs, including:
  1370. * @li var: A mutable Tensor of type TensorType::NumberType().
  1371. * Should be from a Variable().
  1372. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  1373. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1374. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1375. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1376. * @li grad: A Tensor of the same type as "var", for the gradient.
  1377. *@par Attributes:
  1378. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1379. *@par Outputs:
  1380. * @li var: A mutable Tensor. Has the same type as input "var".
  1381. * @li accum: A mutable Tensor. Has the same type as input "var".
  1382. *@par Third-party framework compatibility
  1383. *Compatible with the TensorFlow operator ApplyProximalAdagrad.
  1384. */
  1385. REG_OP(ApplyProximalAdagradD)
  1386. .INPUT(var, TensorType::NumberType())
  1387. .INPUT(accum, TensorType::NumberType())
  1388. .INPUT(lr, TensorType::NumberType())
  1389. .INPUT(l1, TensorType::NumberType())
  1390. .INPUT(l2, TensorType::NumberType())
  1391. .INPUT(grad, TensorType::NumberType())
  1392. .OUTPUT(var, TensorType::NumberType())
  1393. .OUTPUT(accum, TensorType::NumberType())
  1394. .ATTR(use_locking, Bool, false)
  1395. .OP_END_FACTORY_REG(ApplyProximalAdagradD)
  1396. /**
  1397. *@brief Updates entries in "var" and "accum" according to the Proximal Adagrad algorithm.
  1398. * Compared with op ApplyProximalAdagrad, an additional index tensor is input.
  1399. * Only the indices into the first dimensions of "var" and "accum" are updated.
  1400. *@par Inputs:
  1401. * Seven inputs, including:\n
  1402. * @li var: A mutable Tensor.
  1403. * TensorType::NumberType(). Should be a Variable Tensor.
  1404. * @li accum: A mutable Tensor of the same type as "var".
  1405. * Should be a Variable Tensor.
  1406. * @li lr: A Tensor of the same type as "var".
  1407. * Scaling factor. Must be a scalar.
  1408. * @li l1: A Tensor of the same type as "var".
  1409. * L1 regularization. Must be a scalar.
  1410. * @li l2: A Tensor of the same type as "var".
  1411. * L2 regularization. Must be a scalar.
  1412. * @li grad: A Tensor. Has the same type as "var".
  1413. * The gradient.
  1414. * @li indices: A vector of indices into the first dimension of "var" and "accum".
  1415. * TensorType::IndexNumberType().
  1416. *@par Attributes:
  1417. *use_locking: An optional bool. Defaults to "False".\n
  1418. * If "True", updating of the "var" and "accum" tensors will be protected by a lock; \n
  1419. * If "False", the behavior is undefined, but may exhibit less contention.
  1420. *@par Outputs:
  1421. *var: A mutable Tensor. Has the same type as input "var".
  1422. *@par Third-party framework compatibility
  1423. *Compatible with the TensorFlow operator SparseApplyProximalAdagrad.
  1424. */
  1425. REG_OP(SparseApplyProximalAdagrad)
  1426. .INPUT(var, TensorType::NumberType())
  1427. .INPUT(accum, TensorType::NumberType())
  1428. .INPUT(lr, TensorType::NumberType())
  1429. .INPUT(l1, TensorType::NumberType())
  1430. .INPUT(l2, TensorType::NumberType())
  1431. .INPUT(grad, TensorType::NumberType())
  1432. .INPUT(indices, TensorType::IndexNumberType())
  1433. .OUTPUT(var, TensorType::NumberType())
  1434. .ATTR(use_locking, Bool, false)
  1435. .OP_END_FACTORY_REG(SparseApplyProximalAdagrad)
  1436. /**
  1437. *@brief Updates entries in "var" and "accum" according to the Proximal Adagrad algorithm.\n
  1438. * Compared with op ApplyProximalAdagrad, an additional index tensor is input.
  1439. * Only the indices into the first dimensions of "var" and "accum" are updated.
  1440. *@par Inputs:
  1441. * Seven inputs, including:\n
  1442. * @li var: A mutable Tensor.\n
  1443. * TensorType::NumberType(). Should be a Variable Tensor.
  1444. * @li accum: A mutable Tensor of the same type as "var".\n
  1445. * Should be a Variable Tensor.
  1446. * @li lr: A Tensor of the same type as "var".\n
  1447. * Scaling factor. Must be a scalar.
  1448. * @li l1: A Tensor of the same type as "var".\n
  1449. * L1 regularization. Must be a scalar.
  1450. * @li l2: A Tensor of the same type as "var".\n
  1451. * L2 regularization. Must be a scalar.
  1452. * @li grad: A Tensor. Has the same type as "var". \n
  1453. * The gradient.
  1454. * @li indices: A vector of indices into the first dimension of "var" and "accum".\n
  1455. * TensorType::IndexNumberType().
  1456. *@par Attributes:
  1457. *use_locking: An optional bool. Defaults to "False".\n
  1458. * If "True", updating of the "var" and "accum" tensors will be protected by a lock; \n
  1459. * If "False", the behavior is undefined, but may exhibit less contention.
  1460. *@par Outputs:
  1461. *@li var: A mutable Tensor. Has the same type as input "var".
  1462. *@li accum: A mutable Tensor. Has the same type as input "var".
  1463. *@par Third-party framework compatibility
  1464. *Compatible with the TensorFlow operator SparseApplyProximalAdagrad.
  1465. */
  1466. REG_OP(SparseApplyProximalAdagradD)
  1467. .INPUT(var, TensorType::NumberType())
  1468. .INPUT(accum, TensorType::NumberType())
  1469. .INPUT(lr, TensorType::NumberType())
  1470. .INPUT(l1, TensorType::NumberType())
  1471. .INPUT(l2, TensorType::NumberType())
  1472. .INPUT(grad, TensorType::NumberType())
  1473. .INPUT(indices, TensorType::IndexNumberType())
  1474. .OUTPUT(var, TensorType::NumberType())
  1475. .OUTPUT(accum, TensorType::NumberType())
  1476. .ATTR(use_locking, Bool, false)
  1477. .OP_END_FACTORY_REG(SparseApplyProximalAdagradD)
  1478. /**
  1479. *@brief Updates "var" according to the Ftrl-proximal scheme.
  1480. *@par Inputs:
  1481. *Eight inputs, including:
  1482. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1483. * Should be a Variable Tensor.
  1484. * @li accum: A mutable Tensor of the same type as "var".
  1485. * Should be a Variable Tensor.
  1486. * @li linear: A mutable Tensor of the same type as "var".
  1487. * Should be a Variable Tensor.
  1488. * @li grad: A Tensor of the same type as "var", for the gradient.
  1489. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1490. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1491. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1492. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1493. *@par Attributes:
  1494. *use_locking: An optional bool. Defaults to "False".
  1495. * If "True", updating of the "var" and "accum" tensors will be
  1496. * protected by a lock; otherwise the behavior is undefined,
  1497. * but may exhibit less contention.
  1498. *@par Outputs:
  1499. *var: A mutable Tensor. Has the same type as input "var".
  1500. *@par Third-party framework compatibility
  1501. *Compatible with the TensorFlow operator ApplyFtrl.
  1502. */
  1503. REG_OP(ApplyFtrl)
  1504. .INPUT(var, TensorType::NumberType())
  1505. .INPUT(accum, TensorType::NumberType())
  1506. .INPUT(linear, TensorType::NumberType())
  1507. .INPUT(grad, TensorType::NumberType())
  1508. .INPUT(lr, TensorType::NumberType())
  1509. .INPUT(l1, TensorType::NumberType())
  1510. .INPUT(l2, TensorType::NumberType())
  1511. .INPUT(lr_power, TensorType::NumberType())
  1512. .OUTPUT(var, TensorType::NumberType())
  1513. .ATTR(use_locking, Bool, false)
  1514. .OP_END_FACTORY_REG(ApplyFtrl)
  1515. /**
  1516. *@brief Updates "var" according to the Ftrl-proximal scheme.
  1517. *@par Inputs:
  1518. *Eight inputs, including:
  1519. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1520. * Should be a Variable Tensor.
  1521. * @li accum: A mutable Tensor of the same type as "var".
  1522. * Should be a Variable Tensor.
  1523. * @li linear: A mutable Tensor of the same type as "var".
  1524. * Should be a Variable Tensor.
  1525. * @li grad: A Tensor of the same type as "var", for the gradient.
  1526. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1527. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1528. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1529. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1530. *@par Attributes:
  1531. *use_locking: An optional bool. Defaults to "False".
  1532. * If "True", updating of the "var" and "accum" tensors will be
  1533. * protected by a lock; otherwise the behavior is undefined,
  1534. * but may exhibit less contention.
  1535. *@par Outputs:
  1536. *@li var: A mutable Tensor. Has the same type as input "var".
  1537. *@li accum: A mutable Tensor. Has the same type as input "accum".
  1538. *@li linear: A mutable Tensor. Has the same type as input "linear".
  1539. *@par Third-party framework compatibility
  1540. *Compatible with the TensorFlow operator ApplyFtrl.
  1541. */
  1542. REG_OP(ApplyFtrlD)
  1543. .INPUT(var, TensorType::NumberType())
  1544. .INPUT(accum, TensorType::NumberType())
  1545. .INPUT(linear, TensorType::NumberType())
  1546. .INPUT(grad, TensorType::NumberType())
  1547. .INPUT(lr, TensorType::NumberType())
  1548. .INPUT(l1, TensorType::NumberType())
  1549. .INPUT(l2, TensorType::NumberType())
  1550. .INPUT(lr_power, TensorType::NumberType())
  1551. .OUTPUT(var, TensorType::NumberType())
  1552. .OUTPUT(accum, TensorType::NumberType())
  1553. .OUTPUT(linear, TensorType::NumberType())
  1554. .ATTR(use_locking, Bool, false)
  1555. .OP_END_FACTORY_REG(ApplyFtrlD)
  1556. /**
  1557. *@brief Updates "var" according to the Ftrl-proximal scheme.
  1558. *@par Inputs:
  1559. *Nine inputs, including:
  1560. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1561. * Should be a Variable Tensor.
  1562. * @li accum: A mutable Tensor of the same type as "var".
  1563. * Should be a Variable Tensor.
  1564. * @li linear: A mutable Tensor of the same type as "var".
  1565. * Should be a Variable Tensor.
  1566. * @li grad: A Tensor of the same type as "var", for the gradient.
  1567. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1568. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1569. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1570. * @li l2_shrinkage: A Tensor of the same type as "var".
  1571. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1572. *@par Attributes:
  1573. *use_locking: An optional bool. Defaults to "False".
  1574. * If "True", updating of the "var" and "accum" tensors will be
  1575. * protected by a lock; otherwise the behavior is undefined,
  1576. * but may exhibit less contention.
  1577. *@par Outputs:
  1578. *var: A mutable Tensor. Has the same type as input "var".
  1579. *@par Third-party framework compatibility
  1580. *Compatible with the TensorFlow operator ApplyFtrlV2.
  1581. */
  1582. REG_OP(ApplyFtrlV2)
  1583. .INPUT(var, TensorType::NumberType())
  1584. .INPUT(accum, TensorType::NumberType())
  1585. .INPUT(linear, TensorType::NumberType())
  1586. .INPUT(grad, TensorType::NumberType())
  1587. .INPUT(lr, TensorType::NumberType())
  1588. .INPUT(l1, TensorType::NumberType())
  1589. .INPUT(l2, TensorType::NumberType())
  1590. .INPUT(l2_shrinkage, TensorType::NumberType())
  1591. .INPUT(lr_power, TensorType::NumberType())
  1592. .OUTPUT(var, TensorType::NumberType())
  1593. .ATTR(use_locking, Bool, false)
  1594. .OP_END_FACTORY_REG(ApplyFtrlV2)
  1595. /**
  1596. *@brief Updates "var" according to the Ftrl-proximal scheme.
  1597. *@par Inputs:
  1598. *Nine inputs, including:
  1599. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1600. * Should be a Variable Tensor.
  1601. * @li accum: A mutable Tensor of the same type as "var".
  1602. * Should be a Variable Tensor.
  1603. * @li linear: A mutable Tensor of the same type as "var".
  1604. * Should be a Variable Tensor.
  1605. * @li grad: A Tensor of the same type as "var", for the gradient.
  1606. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1607. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1608. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1609. * @li l2_shrinkage: A Tensor of the same type as "var".
  1610. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1611. *@par Attributes:
  1612. *use_locking: An optional bool. Defaults to "False".
  1613. * If "True", updating of the "var" and "accum" tensors will be
  1614. * protected by a lock; otherwise the behavior is undefined,
  1615. * but may exhibit less contention.
  1616. *@par Outputs:
  1617. *@li var: A mutable Tensor. Has the same type as input "var".
  1618. *@li accum: A mutable Tensor. Has the same type as input "accum".
  1619. *@li linear: A mutable Tensor. Has the same type as input "linear".
  1620. *@par Third-party framework compatibility
  1621. *Compatible with the TensorFlow operator ApplyFtrlV2.
  1622. */
  1623. REG_OP(ApplyFtrlV2D)
  1624. .INPUT(var, TensorType::NumberType())
  1625. .INPUT(accum, TensorType::NumberType())
  1626. .INPUT(linear, TensorType::NumberType())
  1627. .INPUT(grad, TensorType::NumberType())
  1628. .INPUT(lr, TensorType::NumberType())
  1629. .INPUT(l1, TensorType::NumberType())
  1630. .INPUT(l2, TensorType::NumberType())
  1631. .INPUT(l2_shrinkage, TensorType::NumberType())
  1632. .INPUT(lr_power, TensorType::NumberType())
  1633. .OUTPUT(var, TensorType::NumberType())
  1634. .OUTPUT(accum, TensorType::NumberType())
  1635. .OUTPUT(linear, TensorType::NumberType())
  1636. .ATTR(use_locking, Bool, false)
  1637. .OP_END_FACTORY_REG(ApplyFtrlV2D)
  1638. /**
  1639. *@brief Updates "var" according to the Adam algorithm.\n
  1640. * lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)\n
  1641. * m_t <- beta1 * m_{t-1} + (1 - beta1) * g\n
  1642. * v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g\n
  1643. * variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
  1644. *
  1645. *@attention Constraints:
  1646. * *The input tensors must have the same shape.*
  1647. *
  1648. *@par Inputs:
  1649. *@li var: A mutable Tensor of the type TensorType::NumberType().
  1650. * Should be from a Variable().
  1651. *@li m: A mutable Tensor of the same type as "var".
  1652. * Should be from a Variable().
  1653. *@li v: A mutable Tensor of the same type as "var".
  1654. * Should be from a Variable().
  1655. *@li beta1_power: A scalar of the same type as "var".
  1656. *@li beta2_power: A scalar of the same type as "var".
  1657. *@li lr: learning_rate. A scalar of the same type as "var".
  1658. *@li beta1: A scalar of the same type as "var".
  1659. *@li beta2: A scalar of the same type as "var".
  1660. *@li epsilon: A scalar of the same type as "var".
  1661. *@li grad: A Tensor of the same type as "var", for the gradient.
  1662. *
  1663. *@par Attributes:
  1664. *@li use_locking: An optional bool. Defaults to "False".
  1665. * If "True", updating of the "var", "m", and "v" tensors will be protected
  1666. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1667. * contention.
  1668. *@li use_nesterov: An optional bool. Defaults to "False".
  1669. * If "True", uses the Nesterov update.
  1670. *
  1671. *@par Outputs:
  1672. * var: A mutable Tensor. Has the same type as input "var".
  1673. *@par Third-party framework compatibility
  1674. *Compatible with the TensorFlow operator ApplyAdam.
  1675. */
  1676. REG_OP(ApplyAdam)
  1677. .INPUT(var, TensorType::NumberType())
  1678. .INPUT(m, TensorType::NumberType())
  1679. .INPUT(v, TensorType::NumberType())
  1680. .INPUT(beta1_power, TensorType::NumberType())
  1681. .INPUT(beta2_power, TensorType::NumberType())
  1682. .INPUT(lr, TensorType::NumberType())
  1683. .INPUT(beta1, TensorType::NumberType())
  1684. .INPUT(beta2, TensorType::NumberType())
  1685. .INPUT(epsilon, TensorType::NumberType())
  1686. .INPUT(grad, TensorType::NumberType())
  1687. .OUTPUT(var, TensorType::NumberType())
  1688. .ATTR(use_locking, Bool, false)
  1689. .ATTR(use_nesterov, Bool, false)
  1690. .OP_END_FACTORY_REG(ApplyAdam)
  1691. /**
  1692. *@brief Updates "var" according to the Adam algorithm.\n
  1693. * lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)\n
  1694. * m_t <- beta1 * m_{t-1} + (1 - beta1) * g\n
  1695. * v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g\n
  1696. * variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
  1697. *
  1698. *@attention Constraints:
  1699. * *The input tensors must have the same shape.*
  1700. *
  1701. *@par Inputs:
  1702. *@li var: A mutable Tensor of the type TensorType::NumberType().
  1703. * Should be from a Variable().
  1704. *@li m: A mutable Tensor of the same type as "var".
  1705. * Should be from a Variable().
  1706. *@li v: A mutable Tensor of the same type as "var".
  1707. * Should be from a Variable().
  1708. *@li beta1_power: A scalar of the same type as "var".
  1709. *@li beta2_power: A scalar of the same type as "var".
  1710. *@li lr: learning_rate. A scalar of the same type as "var".
  1711. *@li beta1: A scalar of the same type as "var".
  1712. *@li beta2: A scalar of the same type as "var".
  1713. *@li epsilon: A scalar of the same type as "var".
  1714. *@li grad: A Tensor of the same type as "var", for the gradient.
  1715. *
  1716. *@par Attributes:
  1717. *@li use_locking: An optional bool. Defaults to "False".
  1718. * If "True", updating of the "var", "m", and "v" tensors will be protected
  1719. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1720. * contention.
  1721. *@li use_nesterov: An optional bool. Defaults to "False".
  1722. * If "True", uses the Nesterov update.
  1723. *
  1724. *@par Outputs:
  1725. *@li var: A mutable tensor. Has the same type as input "var".
  1726. *@li m: A mutable tensor. Has the same type as input "m".
  1727. *@li v: A mutable tensor. Has the same type as input "v".
  1728. *@par Third-party framework compatibility
  1729. *Compatible with the TensorFlow operator ApplyAdam.
  1730. */
  1731. REG_OP(ApplyAdamD)
  1732. .INPUT(var, TensorType::NumberType())
  1733. .INPUT(m, TensorType::NumberType())
  1734. .INPUT(v, TensorType::NumberType())
  1735. .INPUT(beta1_power, TensorType::NumberType())
  1736. .INPUT(beta2_power, TensorType::NumberType())
  1737. .INPUT(lr, TensorType::NumberType())
  1738. .INPUT(beta1, TensorType::NumberType())
  1739. .INPUT(beta2, TensorType::NumberType())
  1740. .INPUT(epsilon, TensorType::NumberType())
  1741. .INPUT(grad, TensorType::NumberType())
  1742. .OUTPUT(var, TensorType::NumberType())
  1743. .OUTPUT(m, TensorType::NumberType())
  1744. .OUTPUT(v, TensorType::NumberType())
  1745. .ATTR(use_locking, Bool, false)
  1746. .ATTR(use_nesterov, Bool, false)
  1747. .OP_END_FACTORY_REG(ApplyAdamD)
  1748. /**
  1749. *@brief Updates "var" according to the Adadelta scheme.
  1750. *@par Inputs:
  1751. *Seven inputs, including:
  1752. * @li var: A mutable Tensor of type TensorType::NumberType().
  1753. * Should be a Variable Tensor.
  1754. * @li accum: A mutable Tensor of the same type as "var".
  1755. * Should be a Variable Tensor.
  1756. * @li accum_update: A mutable Tensor of the same type as "var".
  1757. * Should be a Variable Tensor.
  1758. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1759. * @li rho: A scalar of the same type as "var", for the decay factor.
  1760. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  1761. * @li grad: A Tensor of the same type as "var", for the gradient.
  1762. *@par Attributes:
  1763. *use_locking: An optional bool. Defaults to "False".
  1764. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  1765. * protected by a lock; otherwise the behavior is undefined,
  1766. * but may exhibit less contention.
  1767. *@par Outputs:
  1768. *var: A mutable Tensor. Has the same type as input "var".
  1769. *@par Third-party framework compatibility
  1770. * Compatible with the TensorFlow operator ApplyAdadelta.
  1771. */
  1772. REG_OP(ApplyAdadelta)
  1773. .INPUT(var, TensorType::NumberType())
  1774. .INPUT(accum, TensorType::NumberType())
  1775. .INPUT(accum_update, TensorType::NumberType())
  1776. .INPUT(lr, TensorType::NumberType())
  1777. .INPUT(rho, TensorType::NumberType())
  1778. .INPUT(epsilon, TensorType::NumberType())
  1779. .INPUT(grad, TensorType::NumberType())
  1780. .OUTPUT(var, TensorType::NumberType())
  1781. .ATTR(use_locking, Bool, false)
  1782. .OP_END_FACTORY_REG(ApplyAdadelta)
  1783. /**
  1784. *@brief Updates "var" according to the Adadelta scheme.
  1785. *@par Inputs:
  1786. *Seven inputs, including:
  1787. * @li var: A mutable Tensor of type TensorType::NumberType().
  1788. * Should be a Variable Tensor.
  1789. * @li accum: A mutable Tensor of the same type as "var".
  1790. * Should be a Variable Tensor.
  1791. * @li accum_update: A mutable Tensor of the same type as "var".
  1792. * Should be a Variable Tensor.
  1793. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1794. * @li rho: A scalar of the same type as "var", for the decay factor.
  1795. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  1796. * @li grad: A Tensor of the same type as "var", for the gradient.
  1797. *@par Attributes:
  1798. *use_locking: An optional bool. Defaults to "False".
  1799. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  1800. * protected by a lock; otherwise the behavior is undefined,
  1801. * but may exhibit less contention.
  1802. *@par Outputs:
  1803. *@li var: A mutable Tensor. Has the same type as input "var".
  1804. *@li accum: A mutable Tensor. Has the same type as input "var".
  1805. *@li accum_update: A mutable Tensor. Has the same type as input "var".
  1806. *@par Third-party framework compatibility
  1807. * Compatible with the TensorFlow operator ApplyAdadelta.
  1808. */
  1809. REG_OP(ApplyAdadeltaD)
  1810. .INPUT(var, TensorType::NumberType())
  1811. .INPUT(accum, TensorType::NumberType())
  1812. .INPUT(accum_update, TensorType::NumberType())
  1813. .INPUT(lr, TensorType::NumberType())
  1814. .INPUT(rho, TensorType::NumberType())
  1815. .INPUT(epsilon, TensorType::NumberType())
  1816. .INPUT(grad, TensorType::NumberType())
  1817. .OUTPUT(var, TensorType::NumberType())
  1818. .OUTPUT(accum, TensorType::NumberType())
  1819. .OUTPUT(accum_update, TensorType::NumberType())
  1820. .ATTR(use_locking, Bool, false)
  1821. .OP_END_FACTORY_REG(ApplyAdadeltaD)
  1822. /**
  1823. * @brief Updates "var" and "accum" according to the ApplyMomentum algorithm. \n
  1824. * accum = accum * momentum + x1 * x2 \n
  1825. * if use_nesterov is True: \n
  1826. * var -= x1 * x2 * lr + accum * momentum * lr \n
  1827. * else: \n
  1828. * var -= accum * lr
  1829. *
  1830. * @par Inputs:
  1831. * Six inputs, including:
  1832. * @li var: A mutable Tensor of type TensorType::NumberType().
  1833. * Should be a Variable Tensor.
  1834. * @li accum: A mutable Tensor of the same type as "var".
  1835. * Should be a Variable Tensor.
  1836. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1837. * @li x1: A Tensor of type TensorType::NumberType().
  1838. * @li momentum: A scalar of the same type as "var".
  1839. * @li x2: A scalar of the same type as "var".
  1840. *
  1841. * @par Attributes:
  1842. * Two attributes, including:
  1843. * @li use_nesterov: An optional bool. Defaults to "False". \n
  1844. * If "True", the tensor passed to compute grad will be var - lr * momentum * accum, \n
  1845. * so in the end, the var you get is actually var - lr * momentum * accum.
  1846. * @li use_locking: An optional bool. Defaults to "False". \n
  1847. * If "True", updating of the "var" and "accum" tensors will be protected \n
  1848. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1849. *
  1850. * @par Outputs:
  1851. * Two outputs, including:
  1852. * @li var: A mutable Tensor of the same type as input "var".
  1853. * @li accum: A mutable Tensor of the same type as "var".
  1854. */
  1855. REG_OP(FusedMulApplyMomentum)
  1856. .INPUT(var, TensorType::NumberType())
  1857. .INPUT(accum, TensorType::NumberType())
  1858. .INPUT(lr, TensorType::NumberType())
  1859. .INPUT(x1, TensorType::NumberType())
  1860. .INPUT(momentum, TensorType::NumberType())
  1861. .INPUT(x2, TensorType::NumberType())
  1862. .OUTPUT(var, TensorType::NumberType())
  1863. .OUTPUT(accum, TensorType::NumberType())
  1864. .ATTR(use_nesterov, Bool, false)
  1865. .ATTR(use_locking, Bool, false)
  1866. .OP_END_FACTORY_REG(FusedMulApplyMomentum)
  1867. /**
  1868. * @brief Updates "var", "var_copy" and "accum" according to the ApplyMomentum algorithm. \n
  1869. * accum = accum * momentum + x1 * x2 \n
  1870. * if use_nesterov is True: \n
  1871. * var -= x1 * x2 * lr + accum * momentum * lr \n
  1872. * else: \n
  1873. * var -= accum * lr
  1874. *
  1875. * @par Inputs:
  1876. * Seven inputs, including:
  1877. * @li var: A mutable Tensor of type float32.
  1878. * Should be a Variable Tensor.
  1879. * @li accum: A mutable Tensor of type TensorType::NumberType().
  1880. * Should be a Variable Tensor.
  1881. * @li lr: A scalar of the same type as "accum", for the scaling factor.
  1882. * @li x1: A Tensor of the same type as "accum".
  1883. * @li momentum: A scalar of the same type as "accum".
  1884. * @li x2: A scalar of the same type as "accum".
  1885. * @li var_copy: A Tensor of type float16.
  1886. *
  1887. * @par Attributes:
  1888. * Two attributes, including:
  1889. * @li use_nesterov: An optional bool. Defaults to "False". \n
  1890. * If "True", the tensor passed to compute grad will be var - lr * momentum * accum, \n
  1891. * so in the end, the var you get is actually var - lr * momentum * accum.
  1892. * @li use_locking: An optional bool. Defaults to "False". \n
  1893. * If "True", updating of the "var", "var_copy" and "accum" tensors will be protected \n
  1894. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  1895. *
  1896. * @par Outputs:
  1897. * Three outputs, including:
  1898. * @li var: A Tensor of type float32.
  1899. * @li var_copy: A Tensor of type float16.
  1900. * @li accum: A Tensor of the same type as input "accum".
  1901. */
  1902. REG_OP(FusedMulApplyMomentumExtern)
  1903. .INPUT(var, TensorType(DT_FLOAT))
  1904. .INPUT(accum, TensorType::NumberType())
  1905. .INPUT(lr, TensorType::NumberType())
  1906. .INPUT(x1, TensorType::NumberType())
  1907. .INPUT(momentum, TensorType::NumberType())
  1908. .INPUT(x2, TensorType::NumberType())
  1909. .INPUT(var_copy, TensorType(DT_FLOAT16))
  1910. .OUTPUT(var, TensorType(DT_FLOAT))
  1911. .OUTPUT(var_copy, TensorType(DT_FLOAT16))
  1912. .OUTPUT(accum, TensorType::NumberType())
  1913. .ATTR(use_nesterov, Bool, false)
  1914. .ATTR(use_locking, Bool, false)
  1915. .OP_END_FACTORY_REG(FusedMulApplyMomentumExtern)
  1916. /**
  1917. *@brief Updates "g" according to the LARS algorithm and returns the result as "g_new".
  1918. *@par Inputs:
  1919. *Four inputs, including:
  1920. * @li w: A Tensor. Must be of type DT_FLOAT.
  1921. * @li g: A Tensor of the same type and shape as "w".
  1922. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
  1923. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar.
  1924. *@par Attributes:
  1925. *Three attributes, including:
  1926. * @li hyperpara: An optional float. Defaults to 0.001.
  1927. * @li epsilon: An optional float. Defaults to 1e-5. Used to avoid a zero denominator.
  1928. * @li use_clip: An optional bool. Defaults to "False".\n
  1929. * If "True", updates the learning rate.
  1930. *@par Outputs:
  1931. *g_new: A Tensor of the same type as "w".
  1932. */
  1933. REG_OP(LarsV2)
  1934. .INPUT(w, TensorType(DT_FLOAT))
  1935. .INPUT(g, TensorType(DT_FLOAT))
  1936. .INPUT(weight_decay, TensorType(DT_FLOAT))
  1937. .INPUT(learning_rate, TensorType(DT_FLOAT))
  1938. .OUTPUT(g_new, TensorType(DT_FLOAT))
  1939. .ATTR(hyperpara, Float, 0.001)
  1940. .ATTR(epsilon, Float, 0.00001)
  1941. .ATTR(use_clip, Bool, false)
  1942. .OP_END_FACTORY_REG(LarsV2)
  1943. /**
  1944. *@brief Updates "g" according to the LARS algorithm and returns the result as "g_new".
  1945. *@par Inputs:
  1946. *Six inputs, including:
  1947. * @li w: A Tensor. Must be of type DT_FLOAT.
  1948. * @li g: A Tensor of the same type and shape as "w".
  1949. * @li w_square_sum: A Tensor holding square_sum(w). Has the same type as "w". Must be a scalar.
  1950. * @li g_square_sum: A Tensor holding square_sum(g). Has the same type as "w". Must be a scalar.
  1951. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
  1952. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar.
  1953. *@par Attributes:
  1954. *Three attributes, including:
  1955. * @li hyperpara: An optional float. Defaults to 0.001.
  1956. * @li epsilon: An optional float. Defaults to 1e-5. Used to avoid a zero denominator.
  1957. * @li use_clip: An optional bool. Defaults to "False".\n
  1958. * If "True", updates the learning rate.
  1959. *@par Outputs:
  1960. *g_new: A Tensor of the same type as "w".
  1961. */
  1962. REG_OP(LarsV2Update)
  1963. .INPUT(w, TensorType(DT_FLOAT))
  1964. .INPUT(g, TensorType(DT_FLOAT))
  1965. .INPUT(w_square_sum, TensorType(DT_FLOAT))
  1966. .INPUT(g_square_sum, TensorType(DT_FLOAT))
  1967. .INPUT(weight_decay, TensorType(DT_FLOAT))
  1968. .INPUT(learning_rate, TensorType(DT_FLOAT))
  1969. .OUTPUT(g_new, TensorType(DT_FLOAT))
  1970. .ATTR(hyperpara, Float, 0.001)
  1971. .ATTR(epsilon, Float, 0.00001)
  1972. .ATTR(use_clip, Bool, false)
  1973. .OP_END_FACTORY_REG(LarsV2Update)
  1974. /**
  1975. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  1976. * @par Inputs:
  1977. * Nine inputs, including:
  1978. * @li var: A mutable Tensor. Must be of type DT_FLOAT.
  1979. * Should be a Variable Tensor.
  1980. * @li accum: A mutable Tensor of the same type as "var".
  1981. * Should be a Variable Tensor.
  1982. * @li linear: A mutable Tensor of the same type as "var".
  1983. * Should be a Variable Tensor.
  1984. * @li grad: A Tensor of the same type as "var", for the gradient.
  1985. * @li indices: A vector of indices into the first dimension of "var" and "accum". Must be of type DT_INT32.
  1986. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1987. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1988. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1989. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1990. * @par Attributes:
  1991. * use_locking: An optional bool. Defaults to "False".
  1992. * If "True", updating of the "var" and "accum" tensors will be
  1993. * protected by a lock; otherwise the behavior is undefined,
  1994. * but may exhibit less contention.
  1995. * @par Outputs:
  1996. * var: A Tensor. Has the same type and format as input "var".
  1997. * @par Third-party framework compatibility
  1998. * Compatible with the TensorFlow operator SparseApplyFtrl.
  1999. */
  2000. REG_OP(SparseApplyFtrl)
  2001. .INPUT(var, TensorType({DT_FLOAT}))
  2002. .INPUT(accum, TensorType({DT_FLOAT}))
  2003. .INPUT(linear, TensorType({DT_FLOAT}))
  2004. .INPUT(grad, TensorType({DT_FLOAT}))
  2005. .INPUT(indices, TensorType({DT_INT32}))
  2006. .INPUT(lr, TensorType({DT_FLOAT}))
  2007. .INPUT(l1, TensorType({DT_FLOAT}))
  2008. .INPUT(l2, TensorType({DT_FLOAT}))
  2009. .INPUT(lr_power, TensorType({DT_FLOAT}))
  2010. .OUTPUT(var, TensorType({DT_FLOAT}))
  2011. .ATTR(use_locking, Bool, false)
  2012. .OP_END_FACTORY_REG(SparseApplyFtrl)
  2013. /**
  2014. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  2015. * @par Inputs:
  2016. * Five inputs, including:
  2017. * @li var: A mutable Tensor. Must be of type DT_FLOAT.
  2018. * Should be a Variable Tensor.
  2019. * @li accum: A mutable Tensor of the same type as "var".
  2020. * Should be a Variable Tensor.
  2021. * @li linear: A mutable Tensor of the same type as "var".
  2022. * Should be a Variable Tensor.
  2023. * @li grad: A Tensor of the same type as "var", for the gradient.
  2024. * @li indices: A vector of indices into the first dimension of "var" and "accum". Must be of type DT_INT32.
  2025. * @par Attributes:
  2026. * @li lr: A required float, for the scaling factor.
  2027. * @li l1: A required float, for L1 regularization.
  2028. * @li l2: A required float, for L2 regularization.
  2029. * @li lr_power: A required float, for the scaling factor.
  2030. * @li use_locking: An optional bool. Defaults to "False".
  2031. * If "True", updating of the "var" and "accum" tensors will be
  2032. * protected by a lock; otherwise the behavior is undefined,
  2033. * but may exhibit less contention.
  2034. * @par Outputs:
  2035. * @li var: A Tensor. Has the same type and format as input "var".
  2036. * @li accum: A Tensor. Has the same type and format as input "accum".
  2037. * @li linear: A Tensor. Has the same type and format as input "linear".
  2038. * @par Third-party framework compatibility
  2039. * Compatible with the TensorFlow operator SparseApplyFtrl.
  2040. */
  2041. REG_OP(SparseApplyFtrlD)
  2042. .INPUT(var, TensorType({DT_FLOAT}))
  2043. .INPUT(accum, TensorType({DT_FLOAT}))
  2044. .INPUT(linear, TensorType({DT_FLOAT}))
  2045. .INPUT(grad, TensorType({DT_FLOAT}))
  2046. .INPUT(indices, TensorType({DT_INT32}))
  2047. .OUTPUT(var, TensorType({DT_FLOAT}))
  2048. .OUTPUT(accum, TensorType({DT_FLOAT}))
  2049. .OUTPUT(linear, TensorType({DT_FLOAT}))
  2050. .REQUIRED_ATTR(lr, Float)
  2051. .REQUIRED_ATTR(l1, Float)
  2052. .REQUIRED_ATTR(l2, Float)
  2053. .REQUIRED_ATTR(lr_power, Float)
  2054. .ATTR(use_locking, Bool, false)
  2055. .OP_END_FACTORY_REG(SparseApplyFtrlD)
  2056. /**
  2057. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  2058. * That is, for the rows we have grad for, "var", "accum" and "linear" are updated.
  2059. * @par Inputs:
  2060. * Ten inputs, including:
  2061. * @li var: A mutable Tensor. Must be of type DT_FLOAT.
  2062. * Should be a Variable Tensor.
  2063. * @li accum: A mutable Tensor of the same type as "var".
  2064. * Should be a Variable Tensor.
  2065. * @li linear: A mutable Tensor of the same type as "var".
  2066. * Should be a Variable Tensor.
  2067. * @li grad: A Tensor of the same type as "var", for the gradient.
  2068. * @li indices: A vector of indices into the first dimension of "var" and "accum". Must be of type DT_INT32.
  2069. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  2070. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  2071. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  2072. * @li l2_shrinkage: A Tensor of the same type as "var", for L2 shrinkage regularization. Must be a scalar.
  2073. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  2074. * @par Attributes:
  2075. * use_locking: An optional bool. Defaults to "False".
  2076. * If "True", updating of the "var" and "accum" tensors will be
  2077. * protected by a lock; otherwise the behavior is undefined,
  2078. * but may exhibit less contention.
  2079. * @par Outputs:
  2080. * var: A Tensor. Has the same type and format as input "var".
  2081. * @par Third-party framework compatibility
  2082. * Compatible with the TensorFlow operator SparseApplyFtrlV2.
  2083. */
  2084. REG_OP(SparseApplyFtrlV2)
  2085. .INPUT(var, TensorType({DT_FLOAT}))
  2086. .INPUT(accum, TensorType({DT_FLOAT}))
  2087. .INPUT(linear, TensorType({DT_FLOAT}))
  2088. .INPUT(grad, TensorType({DT_FLOAT}))
  2089. .INPUT(indices, TensorType({DT_INT32}))
  2090. .INPUT(lr, TensorType({DT_FLOAT}))
  2091. .INPUT(l1, TensorType({DT_FLOAT}))
  2092. .INPUT(l2, TensorType({DT_FLOAT}))
  2093. .INPUT(l2_shrinkage, TensorType({DT_FLOAT}))
  2094. .INPUT(lr_power, TensorType({DT_FLOAT}))
  2095. .OUTPUT(var, TensorType({DT_FLOAT}))
  2096. .ATTR(use_locking, Bool, false)
  2097. .OP_END_FACTORY_REG(SparseApplyFtrlV2)
  2098. /**
  2099. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  2100. * That is, for the rows we have grad for, "var", "accum" and "linear" are updated.
  2101. * @par Inputs:
  2102. * Five inputs, including:
  2103. * @li var: A mutable Tensor. Must be of type DT_FLOAT.
  2104. * Should be a Variable Tensor.
  2105. * @li accum: A mutable Tensor of the same type as "var".
  2106. * Should be a Variable Tensor.
  2107. * @li linear: A mutable Tensor of the same type as "var".
  2108. * Should be a Variable Tensor.
  2109. * @li grad: A Tensor of the same type as "var", for the gradient.
  2110. * @li indices: A vector of indices into the first dimension of "var" and "accum". Must be of type DT_INT32.
  2111. * @par Attributes:
  2112. * @li lr: A required float, for the scaling factor.
  2113. * @li l1: A required float, for L1 regularization.
  2114. * @li l2: A required float, for L2 regularization.
  2115. * @li l2_shrinkage: A required float, for L2 shrinkage regularization.
  2116. * @li lr_power: A required float, for the scaling factor.
  2117. * @li use_locking: An optional bool. Defaults to "False".
  2118. * If "True", updating of the "var" and "accum" tensors will be
  2119. * protected by a lock; otherwise the behavior is undefined,
  2120. * but may exhibit less contention.
  2121. * @par Outputs:
  2122. * @li var: A Tensor. Has the same type and format as input "var".
  2123. * @li accum: A Tensor. Has the same type and format as input "accum".
  2124. * @li linear: A Tensor. Has the same type and format as input "linear".
  2125. * @par Third-party framework compatibility
  2126. * Compatible with the TensorFlow operator SparseApplyFtrlV2.
  2127. */
  2128. REG_OP(SparseApplyFtrlV2D)
  2129. .INPUT(var, TensorType({DT_FLOAT}))
  2130. .INPUT(accum, TensorType({DT_FLOAT}))
  2131. .INPUT(linear, TensorType({DT_FLOAT}))
  2132. .INPUT(grad, TensorType({DT_FLOAT}))
  2133. .INPUT(indices, TensorType({DT_INT32}))
  2134. .OUTPUT(var, TensorType({DT_FLOAT}))
  2135. .OUTPUT(accum, TensorType({DT_FLOAT}))
  2136. .OUTPUT(linear, TensorType({DT_FLOAT}))
  2137. .REQUIRED_ATTR(lr, Float)
  2138. .REQUIRED_ATTR(l1, Float)
  2139. .REQUIRED_ATTR(l2, Float)
  2140. .REQUIRED_ATTR(l2_shrinkage, Float)
  2141. .REQUIRED_ATTR(lr_power, Float)
  2142. .ATTR(use_locking, Bool, false)
  2143. .OP_END_FACTORY_REG(SparseApplyFtrlV2D)
  2144. /**
  2145. * @brief Updates "var" at the specified indices according to the RMSProp algorithm.\n
  2146. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  2147. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  2148. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  2149. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  2150. * var <- var - mom
  2151. *
  2152. * @par Inputs:
  2153. * @li var: A mutable tensor. Must be one of the data types defined in
  2154. * TensorType::NumberType(). Should be from a Variable().
  2155. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  2156. * Variable().
  2157. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  2158. * Variable().
  2159. * @li lr: A scalar. Must have the same type as "var".
  2160. * @li rho: A scalar. Must have the same type as "var".
  2161. * @li momentum: A scalar. Must have the same type as "var".
  2162. * @li epsilon: A scalar. Must have the same type as "var".
  2163. * @li grad: A tensor, specifying the gradient.
  2164. * @li indices: A vector of indices into the first dimension of "var", "mom" and "ms". Must be one of the data types defined in TensorType::IndexNumberType().
  2165. *
  2166. * @par Attributes:
  2167. * use_locking: An optional "bool". Defaults to "False". If "True", updating of
  2168. * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the
  2169. * behavior is undefined, but may exhibit less contention.
  2170. *
  2171. * @par Outputs:
  2172. * var: A mutable tensor. Has the same type as input "var".
  2173. *
  2174. * @attention Constraints:
  2175. * @li Note that in this sparse implementation, "ms" and "mom" will not update
  2176. * in iterations during which "grad" is 0.
  2177. * @li The input tensors "var", "ms", and "mom" must have the same shape.
  2178. *
  2179. * @par Third-party framework compatibility
  2180. * Compatible with the TensorFlow operator SparseApplyRMSProp.
  2181. */
  2182. REG_OP(SparseApplyRMSProp)
  2183. .INPUT(var, TensorType::NumberType())
  2184. .INPUT(ms, TensorType::NumberType())
  2185. .INPUT(mom, TensorType::NumberType())
  2186. .INPUT(lr, TensorType::NumberType())
  2187. .INPUT(rho, TensorType::NumberType())
  2188. .INPUT(momentum, TensorType::NumberType())
  2189. .INPUT(epsilon, TensorType::NumberType())
  2190. .INPUT(grad, TensorType::NumberType())
  2191. .INPUT(indices, TensorType::IndexNumberType())
  2192. .OUTPUT(var, TensorType::NumberType())
  2193. .ATTR(use_locking, Bool, false)
  2194. .OP_END_FACTORY_REG(SparseApplyRMSProp)
  2195. /**
  2196. * @brief Updates "var" at the specified indices according to the RMSProp algorithm.
  2197. * A const input is treated as an attribute.\n
  2198. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  2199. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  2200. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  2201. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  2202. * var <- var - mom
  2203. *
  2204. * @par Inputs:
  2205. * @li var: A mutable tensor. Must be one of the data types defined in
  2206. * TensorType::NumberType(). Should be from a Variable().
  2207. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  2208. * Variable().
  2209. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  2210. * Variable().
  2211. * @li lr: A scalar. Must have the same type as "var".
  2212. * @li grad: A tensor, specifying the gradient.
  2213. * @li indices: A vector of indices into the first dimension of "var", "mom" and "ms".
  2214. * @par Attributes:
  2215. * @li use_locking: An optional "bool". Defaults to "False". If "True",
  2216. * updating of the "var", "ms", and "mom" tensors will be protected by a lock;
  2217. * otherwise the behavior is undefined, but may exhibit less contention.
  2218. * @li rho: A required float.
  2219. * @li momentum: A required float.
  2220. * @li epsilon: A required float.
  2221. *
  2222. * @par Outputs:
  2223. * @li var: A mutable tensor. Must have the same type as input "var".
  2224. * @li ms: A mutable tensor. Must have the same type as input "ms".
  2225. * @li mom: A mutable tensor. Must have the same type as input "mom".
  2226. *
  2227. * @attention Constraints:
  2228. * @li Note that in this sparse implementation, "ms" and "mom" will not update
  2229. * in iterations during which "grad" is 0.
  2230. * @li The input tensors "var", "ms" and "mom" must have the same shape.
  2231. */
  2232. REG_OP(SparseApplyRMSPropD)
  2233. .INPUT(var, TensorType::NumberType())
  2234. .INPUT(ms, TensorType::NumberType())
  2235. .INPUT(mom, TensorType::NumberType())
  2236. .INPUT(lr, TensorType::NumberType())
  2237. .INPUT(grad, TensorType::NumberType())
  2238. .INPUT(indices, TensorType::IndexNumberType())
  2239. .OUTPUT(var, TensorType::NumberType())
  2240. .OUTPUT(ms, TensorType::NumberType())
  2241. .OUTPUT(mom, TensorType::NumberType())
  2242. .REQUIRED_ATTR(rho, Float)
  2243. .REQUIRED_ATTR(momentum, Float)
  2244. .REQUIRED_ATTR(epsilon, Float)
  2245. .ATTR(use_locking, Bool, false)
  2246. .OP_END_FACTORY_REG(SparseApplyRMSPropD)
  2247. /**
  2248. * @brief Updates "var" at the specified indices according to the Adadelta algorithm.\n
  2249. * accum <- rho * accum + (1 - rho) * grad.square()\n
  2250. * update <- (accum_update + epsilon).sqrt() * (accum + epsilon).rsqrt() * grad\n
  2251. * var <- var - update * lr\n
  2252. * accum_update <- rho * accum_update + (1 - rho) * update.square()
  2253. *
  2254. * @par Inputs:
  2255. * @li var: A mutable tensor. Must be one of the data types defined in
  2256. * TensorType::NumberType(). Should be from a Variable().
  2257. * @li accum: A mutable tensor. Must have the same type as "var". Should be from a
  2258. * Variable().
  2259. * @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a
  2260. * Variable().
  2261. * @li lr: A scalar. Must have the same type as "var".
  2262. * @li rho: A scalar. Must have the same type as "var".
  2263. * @li epsilon: A scalar. Must have the same type as "var".
  2264. * @li grad: A tensor, specifying the gradient.
  2265. * @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update".
  2266. *
  2267. * @par Attributes:
  2268. * use_locking: An optional "bool". Defaults to "False". If "True", updating of
  2269. * the "var", "accum", and "accum_update" tensors will be protected by a lock; otherwise the
  2270. * behavior is undefined, but may exhibit less contention.
  2271. *
  2272. * @par Outputs:
  2273. * var: A mutable tensor. Has the same type as input "var".
  2274. *
  2275. * @attention Constraints:
  2276. * @li Note that in this sparse implementation, "accum" and "accum_update" will not update
  2277. * in iterations during which "grad" is 0.
  2278. * @li The input tensors "var", "accum", and "accum_update" must have the same shape.
  2279. *
  2280. * @par Third-party framework compatibility
  2281. * Compatible with the TensorFlow operator SparseApplyAdadelta.
  2282. */
  2283. REG_OP(SparseApplyAdadelta)
  2284. .INPUT(var, TensorType::NumberType())
  2285. .INPUT(accum, TensorType::NumberType())
  2286. .INPUT(accum_update, TensorType::NumberType())
  2287. .INPUT(lr, TensorType::NumberType())
  2288. .INPUT(rho, TensorType::NumberType())
  2289. .INPUT(epsilon, TensorType::NumberType())
  2290. .INPUT(grad, TensorType::NumberType())
  2291. .INPUT(indices, TensorType::IndexNumberType())
  2292. .OUTPUT(var, TensorType::NumberType())
  2293. .ATTR(use_locking, Bool, false)
  2294. .OP_END_FACTORY_REG(SparseApplyAdadelta)
  2295. /**
  2296. * @brief Updates "var" at the specified indices according to the Adadelta algorithm.
  2297. * A const input is treated as an attribute.\n
  2298. * accum <- rho * accum + (1 - rho) * grad.square()\n
  2299. * update <- (accum_update + epsilon).sqrt() * (accum + epsilon).rsqrt() * grad\n
  2300. * var <- var - update * lr\n
  2301. * accum_update <- rho * accum_update + (1 - rho) * update.square()
  2302. *
  2303. * @par Inputs:
  2304. * @li var: A mutable tensor. Must be one of the data types defined in
  2305. * TensorType::NumberType(). Should be from a Variable().
  2306. * @li accum: A mutable tensor. Must have the same type as "var". Should be from a
  2307. * Variable().
  2308. * @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a
  2309. * Variable().
  2310. * @li lr: A scalar. Must have the same type as "var".
  2311. * @li rho: A scalar. Must have the same type as "var".
  2312. * @li grad: A tensor, specifying the gradient.
  2313. * @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update".
  2314. *
  2315. * @par Attributes:
  2316. * @li use_locking: An optional "bool". Defaults to "False". If "True",
  2317. * updating of the "var", "accum", and "accum_update" tensors will be protected by a lock;
  2318. * otherwise the behavior is undefined, but may exhibit less contention.
  2319. * @li epsilon: A required float.
  2320. *
  2321. * @par Outputs:
  2322. * @li var: A mutable tensor. Must have the same type as input "var".
  2323. * @li accum: A mutable tensor. Must have the same type as input "accum".
  2324. * @li accum_update: A mutable tensor. Must have the same type as input "accum_update".
  2325. *
  2326. * @attention Constraints:
  2327. * @li Note that in this sparse implementation, "accum" and "accum_update" will not update
  2328. * in iterations during which "grad" is 0.
  2329. * @li The input tensors "var", "accum" and "accum_update" must have the same shape.
  2330. */
  2331. REG_OP(SparseApplyAdadeltaD)
  2332. .INPUT(var, TensorType::NumberType())
  2333. .INPUT(accum, TensorType::NumberType())
  2334. .INPUT(accum_update, TensorType::NumberType())
  2335. .INPUT(lr, TensorType::NumberType())
  2336. .INPUT(rho, TensorType::NumberType())
  2337. .INPUT(grad, TensorType::NumberType())
  2338. .INPUT(indices, TensorType::IndexNumberType())
  2339. .OUTPUT(var, TensorType::NumberType())
  2340. .OUTPUT(accum, TensorType::NumberType())
  2341. .OUTPUT(accum_update, TensorType::NumberType())
  2342. .REQUIRED_ATTR(epsilon, Float)
  2343. .ATTR(use_locking, Bool, false)
  2344. .OP_END_FACTORY_REG(SparseApplyAdadeltaD)
  2345. /**
  2346. *@brief Cleans the memory of the workspace list.
  2347. *@par Attributes:
  2348. * automic_add_mem_size: An optional list of ints, specifying the sizes of the workspaces. Defaults to an empty list.
  2349. */
  2350. REG_OP(AtomicAddrClean)
  2351. .ATTR(automic_add_mem_size, ListInt, {})
  2352. .OP_END_FACTORY_REG(AtomicAddrClean)
  2353. } // namespace ge
  2354. #endif // GE_OP_TRAINING_OPS_H

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示