You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

control_flow_ops.h 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_CONTROL_FLOW_OPS_H_
  17. #define GE_CONTROL_FLOW_OPS_H_
  18. #include "graph/operator_reg.h"
  19. #include "graph/operator.h"
  20. namespace ge {
/**
*@brief Forwards the value of an available tensor from input "x" to output "y". \n
* Merge waits for at least one of the input tensors to become available. \n
* It is usually combined with Switch to implement branching. \n
* Merge forwards the first tensor to become available to output "y", \n
* and sets "value_index" to the index of the chosen input tensor.
*@par Inputs:
*x: The input tensors, one of which will become available. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*@li y: The available tensor. Has the same type as "x".
*@li value_index: A scalar of type int32, for the index of the chosen input \n
* tensor.
*@see Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Merge.
*/
REG_OP(Merge)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(value_index, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(Merge)
/**
*@brief Forwards the value of an available ref tensor from input "x" to output "y". \n
* Merge waits for at least one of the input tensors to become available. \n
* It is usually combined with Switch to implement branching. \n
* Merge forwards the first tensor to become available to output "y", \n
* and sets "value_index" to the index of the chosen input tensor.
*@par Inputs:
*x: The input tensors, one of which will become available. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*@li y: The available tensor. Has the same type as "x".
*@li value_index: A scalar of type int32, for the index of the chosen input \n
* tensor.
*@see Switch() | Merge()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefMerge.
*/
REG_OP(RefMerge)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(value_index, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(RefMerge)
/**
*@brief Forwards "data" to the output port determined by "pred". \n
* If "pred" is "true", the data input is forwarded to "output_true". \n
* Otherwise, the data is forwarded to "output_false".
*@par Inputs:
*@li data: The tensor to be forwarded. \n
* Must be one of the following types: float16, float32, float64, \n
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred: A boolean scalar. Selects the output port that will receive data.
*@par Outputs:
*@li output_false: If "pred" is "false", data will be forwarded to this output. \n
* Has the same type as "data".
*@li output_true: If "pred" is "true", data will be forwarded to this output. \n
* Has the same type as "data".
*@see Merge()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Switch.
*/
REG_OP(Switch)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred, TensorType({DT_BOOL}))
    .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(Switch)
/**
*@brief Forwards "data" to the output port determined by "pred". \n
* If "pred" is "true", the data input is forwarded to "output_true". \n
* Otherwise, the data is forwarded to "output_false".
*@par Inputs:
*@li data: The ref tensor to be forwarded. \n
* Must be one of the following types: float16, float32, float64, \n
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred: A boolean scalar. Selects the output port that will receive data.
*@par Outputs:
*@li output_false: If "pred" is "false", data will be forwarded to this output. \n
* Has the same type as "data".
*@li output_true: If "pred" is "true", data will be forwarded to this output. \n
* Has the same type as "data".
*@see Merge() | Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefSwitch.
*/
REG_OP(RefSwitch)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred, TensorType({DT_BOOL}))
    .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefSwitch)
/**
*@brief Forwards "data" to the output port determined by "pred_value".
*@par Inputs:
*@li data: The tensor to be forwarded. \n
* Must be one of the following types: float16, float32, float64, \n
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred_value: An int64 tensor which determines the output port that \n
* will receive data.
*@par Outputs:
*output: The output tensors, one of which will become available. \n
* Has the same type as "data".
*/
REG_OP(SwitchN)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred_value, TensorType({DT_INT64}))
    .DYNAMIC_OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(SwitchN)
/**
*@brief Creates or finds a child frame, and makes "x" available to the child \n
* frame. This op is used together with Exit to create loops in the graph. \n
* The Executor uses the unique "frame_name" to identify frames. \n
* If "is_constant" is "true", output "y" is a constant in the child \n
* frame; otherwise it may be changed in the child frame.
*@par Inputs:
*x: The tensor to be made available to the child frame. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Attributes:
*@li frame_name: A required string. The name of the child frame.
*@li is_constant: A required bool. If true, the output is constant in \n
* the child frame.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@see Exit()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Enter.
*/
REG_OP(Enter)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .REQUIRED_ATTR(frame_name, String)
    .REQUIRED_ATTR(is_constant, Bool)
    .OP_END_FACTORY_REG(Enter)
/**
*@brief Creates or finds a child frame, and makes "x" available to the child \n
* frame. This op is used together with Exit to create loops in the graph. \n
* The Executor uses the unique "frame_name" to identify frames. \n
* If "is_constant" is "true", output "y" is a constant in the child \n
* frame; otherwise it may be changed in the child frame.
*@par Inputs:
*x: The tensor to be made available to the child frame. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Attributes:
*@li frame_name: A required string. The name of the child frame.
*@li is_constant: A required bool. If true, the output is constant in \n
* the child frame.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@see Exit() | Enter()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefEnter.
*/
REG_OP(RefEnter)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .REQUIRED_ATTR(frame_name, String)
    .REQUIRED_ATTR(is_constant, Bool)
    .OP_END_FACTORY_REG(RefEnter)
/**
*@brief Forwards the input to the output. This op represents the loop \n
* termination condition.
*@par Inputs:
*x: A boolean scalar. The condition of the Switch op.
*@par Outputs:
*y: The tensor "x".
*@see Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator LoopCond.
*/
REG_OP(LoopCond)
    .INPUT(x, TensorType({DT_BOOL}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(LoopCond)
/**
*@brief Makes the input available to the next iteration.
*@par Inputs:
*x: The tensor to be made available to the next iteration. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator NextIteration.
*/
REG_OP(NextIteration)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(NextIteration)
/**
*@brief Makes the ref input available to the next iteration.
*@par Inputs:
*x: The tensor to be made available to the next iteration. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefNextIteration.
*/
REG_OP(RefNextIteration)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefNextIteration)
/**
*@brief Exits the current frame to its parent frame.
*@par Inputs:
*x: The tensor to be made available to the parent frame. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@see Enter()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Exit.
*/
REG_OP(Exit)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(Exit)
/**
*@brief Exits the current frame to its parent frame (ref variant).
*@par Inputs:
*x: The tensor to be made available to the parent frame. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@see Enter() | Exit()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefExit.
*/
REG_OP(RefExit)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefExit)
/**
*@brief Only useful as a placeholder for control edges. \n
* It is similar to a no-op that always produces a live control output \n
* even when some control inputs are dead.
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator ControlTrigger.
*/
REG_OP(ControlTrigger)
    .OP_END_FACTORY_REG(ControlTrigger)
  317. } // namespace ge
  318. #endif // GE_CONTROL_FLOW_OPS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示