You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

control_flow_ops.h 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_CONTROL_FLOW_OPS_H_
  17. #define GE_CONTROL_FLOW_OPS_H_
  18. #include "graph/operator_reg.h"
  19. #include "graph/operator.h"
  20. namespace ge {
/**
*@brief Forwards the value of an available tensor from input "x" to output "y". \n
* Merge waits for at least one of the input tensors to become available. \n
* It is usually combined with Switch to implement branching. \n
* Merge forwards the first tensor to become available to output "y", \n
* and sets "value_index" to the index of that tensor in the inputs.
*@par Inputs:
*x: The input tensors, one of which will become available. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*@li y: The available tensor. Has the same type as "x".
*@li value_index: A scalar of type int32, for the index of the chosen input \n
* tensor.
*@see Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Merge.
*/
REG_OP(Merge)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(value_index, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(Merge)
/**
*@brief Forwards the value of an available ref tensor from input "x" to output "y". \n
* Merge waits for at least one of the input tensors to become available. \n
* It is usually combined with Switch to implement branching. \n
* Merge forwards the first tensor to become available to output "y", \n
* and sets "value_index" to the index of that tensor in the inputs.
*@par Inputs:
*x: The input tensors, one of which will become available. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*@li y: The available tensor. Has the same type as "x".
*@li value_index: A scalar of type int32, for the index of the chosen input \n
* tensor.
*@see Switch() | Merge()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefMerge.
*/
REG_OP(RefMerge)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(value_index, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(RefMerge)
/**
*@brief Forwards "data" to the output port determined by "pred". \n
* If "pred" is "true", the data input is forwarded to "output_true". \n
* Otherwise, the data is forwarded to "output_false".
*@par Inputs:
*@li data: The tensor to be forwarded. \n
* Must be one of the following types: float16, float32, float64, \n
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred: A boolean scalar. Selects the output port that will receive data.
*@par Outputs:
*@li output_false: If "pred" is "false", data will be forwarded to this output. \n
* Has the same type as "data".
*@li output_true: If "pred" is "true", data will be forwarded to this output. \n
* Has the same type as "data".
*@see Merge()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Switch.
*/
REG_OP(Switch)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred, TensorType({DT_BOOL}))
    .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(Switch)
/**
*@brief Forwards the ref tensor "data" to the output port determined by "pred". \n
* If "pred" is "true", the data input is forwarded to "output_true". \n
* Otherwise, the data is forwarded to "output_false".
*@par Inputs:
*@li data: The ref tensor to be forwarded. \n
* Must be one of the following types: float16, float32, float64, \n
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred: A boolean scalar. Selects the output port that will receive data.
*@par Outputs:
*@li output_false: If "pred" is "false", data will be forwarded to this output. \n
* Has the same type as "data".
*@li output_true: If "pred" is "true", data will be forwarded to this output. \n
* Has the same type as "data".
*@see Merge() | Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefSwitch.
*/
REG_OP(RefSwitch)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred, TensorType({DT_BOOL}))
    .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefSwitch)
/**
*@brief Forwards "data" to one of the dynamic outputs selected by "pred_value". \n
* NOTE(review): this op had no documentation; semantics are inferred from the \n
* signature — presumably output[pred_value] receives the data, analogous to a \n
* multi-way Switch. Confirm against the GE runtime implementation.
*@par Inputs:
*@li data: The tensor to be forwarded. \n
* Must be one of the following types: float16, float32, float64, \n
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred_value: A scalar of type int64 that selects the output branch.
*@par Outputs:
*output: A dynamic number of outputs; the selected output receives "data". \n
* Has the same type as "data".
*@see Switch()
*/
REG_OP(SwitchN)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred_value, TensorType({DT_INT64}))
    .DYNAMIC_OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(SwitchN)
/**
*@brief Creates or finds a child frame, and makes "x" available to the child \n
* frame. This op is used together with Exit to create loops in the graph. \n
* The Executor uses the unique "frame_name" to identify frames. \n
* If "is_constant" is "true", output "y" is a constant in the child \n
* frame; otherwise it may be changed in the child frame.
*@par Inputs:
*x: The tensor to be made available to the child frame. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Attributes:
*@li frame_name: A required string. The name of the child frame.
*@li is_constant: A required bool. If true, the output is constant in \n
* the child frame.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@see Exit()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Enter.
*/
REG_OP(Enter)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .REQUIRED_ATTR(frame_name, String)
    .REQUIRED_ATTR(is_constant, Bool)
    .OP_END_FACTORY_REG(Enter)
/**
*@brief Creates or finds a child frame, and makes the ref tensor "x" available \n
* to the child frame. This op is used together with Exit to create loops in \n
* the graph. The Executor uses the unique "frame_name" to identify frames. \n
* If "is_constant" is "true", output "y" is a constant in the child \n
* frame; otherwise it may be changed in the child frame.
*@par Inputs:
*x: The tensor to be made available to the child frame. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Attributes:
*@li frame_name: A required string. The name of the child frame.
*@li is_constant: A required bool. If true, the output is constant in \n
* the child frame.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@see Exit() | Enter()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefEnter.
*/
REG_OP(RefEnter)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .REQUIRED_ATTR(frame_name, String)
    .REQUIRED_ATTR(is_constant, Bool)
    .OP_END_FACTORY_REG(RefEnter)
/**
*@brief Forwards the input to the output. This op represents the loop \n
* termination condition.
*@par Inputs:
*x: A boolean scalar. The condition of the Switch op.
*@par Outputs:
*y: The tensor "x". Has type bool.
*@see Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator LoopCond.
*/
REG_OP(LoopCond)
    .INPUT(x, TensorType({DT_BOOL}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(LoopCond)
/**
*@brief Makes the input available to the next iteration of the loop.
*@par Inputs:
*x: The tensor to be made available to the next iteration. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator NextIteration.
*/
REG_OP(NextIteration)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(NextIteration)
/**
*@brief Makes the ref tensor input available to the next iteration of the loop.
*@par Inputs:
*x: The tensor to be made available to the next iteration. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefNextIteration.
*/
REG_OP(RefNextIteration)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefNextIteration)
/**
*@brief Exits the current frame to its parent frame.
*@par Inputs:
*x: The tensor to be made available to the parent frame. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@see Enter()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Exit.
*/
REG_OP(Exit)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(Exit)
/**
*@brief Exits the current frame to its parent frame (ref tensor variant).
*@par Inputs:
*x: The tensor to be made available to the parent frame. \n
* Must be one of the following types: float16, float32, float64, int8, \n
* int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@par Outputs:
*y: A tensor. Has the same type as "x".
*@see Enter() | Exit()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefExit.
*/
REG_OP(RefExit)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefExit)
/**
*@brief Only useful as a placeholder for control edges. \n
* It is similar to a no-op that always produces a live control output \n
* even when some control inputs are dead. It has no data inputs, \n
* outputs, or attributes.
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator ControlTrigger.
*/
REG_OP(ControlTrigger)
    .OP_END_FACTORY_REG(ControlTrigger)
  306. } // namespace ge
  307. #endif // GE_CONTROL_FLOW_OPS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示