You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

cond_pass.cc 16 kB

5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/cond_pass.h"
  17. #include "common/op/ge_op_utils.h"
  18. #include "graph/utils/graph_utils.h"
  19. #include "graph/utils/type_utils.h"
  20. #include "graph/utils/node_utils.h"
  namespace {
  // Op type passed to InsertNode when the cond input is a string tensor.
  const std::string kStringLength("StringLength");
  }  // namespace
  24. namespace ge {
  25. Status CondPass::Run(NodePtr &node) {
  26. ComputeGraphPtr graph = nullptr;
  27. OutDataAnchorPtr peer_out_anchor = nullptr;
  28. InDataAnchorPtr cond_in_anchor = nullptr;
  29. Status ret = GetCondInfo(node, graph, peer_out_anchor, cond_in_anchor);
  30. if (ret == NOT_CHANGED) {
  31. return SUCCESS;
  32. } else if (ret != SUCCESS) {
  33. GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str());
  34. return FAILED;
  35. }
  36. /// cond
  37. /// 1. NonScalar: cond->Size(int32)->If / NetOutput(while)
  38. /// 2. String Scalar: cond->StringLength(int32)->If / NetOutput(while)
  39. /// 3. bool / float / double / uint8 / int16 / int8 / int64 Scalar: cond->Cast(2int32)->If / NetOutput(while)
  40. /// 4. Int32 Scalar: cond->If / NetOutput(while)
  41. OpDescPtr op_desc = cond_in_anchor->GetOwnerNode()->GetOpDesc();
  42. GE_CHECK_NOTNULL(op_desc);
  43. GELOGI("Handle cond for node %s.", op_desc->GetName().c_str());
  44. GeTensorDesc cond_tensor = op_desc->GetInputDesc(cond_in_anchor->GetIdx());
  45. if (cond_tensor.MutableShape().GetDim(0) == UNKNOWN_DIM_NUM) {
  46. GELOGI("Output tensor rank of Cond is unknown.");
  47. if (cond_tensor.GetDataType() == DT_STRING) {
  48. GE_CHK_STATUS_RET(HandleStringCond(graph, peer_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.",
  49. op_desc->GetName().c_str())
  50. }
  51. return SUCCESS;
  52. }
  53. if (!cond_tensor.GetShape().IsScalar()) {
  54. GE_CHK_STATUS_RET(HandleNonScalarCond(graph, peer_out_anchor, cond_in_anchor), "HandleNonScalarCond for %s failed.",
  55. op_desc->GetName().c_str())
  56. } else {
  57. switch (cond_tensor.GetDataType()) {
  58. case DT_STRING:
  59. GE_CHK_STATUS_RET(HandleStringCond(graph, peer_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.",
  60. op_desc->GetName().c_str())
  61. break;
  62. case DT_BOOL:
  63. case DT_FLOAT:
  64. case DT_DOUBLE:
  65. case DT_UINT8:
  66. case DT_INT16:
  67. case DT_INT8:
  68. case DT_INT64:
  69. GE_CHK_STATUS_RET(HandleScalarCond(graph, peer_out_anchor, cond_in_anchor, cond_tensor.GetDataType()),
  70. "HandleScalarCond for %s failed.", op_desc->GetName().c_str())
  71. break;
  72. case DT_INT32:
  73. break;
  74. default:
  75. REPORT_INNER_ERROR("E19999",
  76. "data_type:%d of index:%d input tensor in op:%s(%s) check invalid when CondPass %s",
  77. cond_tensor.GetDataType(), cond_in_anchor->GetIdx(),
  78. op_desc->GetName().c_str(), op_desc->GetType().c_str(), __FUNCTION__);
  79. GELOGE(FAILED, "UpdateInputDesc for node %s failed.", op_desc->GetName().c_str());
  80. return FAILED;
  81. }
  82. }
  83. cond_tensor.SetDataType(DT_INT32);
  84. cond_tensor.SetOriginDataType(DT_INT32);
  85. cond_tensor.SetShape(GeShape());
  86. cond_tensor.SetOriginShape(GeShape());
  87. if (op_desc->UpdateInputDesc(cond_in_anchor->GetIdx(), cond_tensor) != GRAPH_SUCCESS) {
  88. REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:%d, when CondPass %s",
  89. op_desc->GetName().c_str(), op_desc->GetType().c_str(), cond_in_anchor->GetIdx(), __FUNCTION__);
  90. GELOGE(FAILED, "UpdateInputDesc for node %s failed.", op_desc->GetName().c_str());
  91. return FAILED;
  92. }
  93. return SUCCESS;
  94. }
  95. ///
  96. /// @brief Get cond info for if / while
  97. /// @param [in] node: If / While op
  98. /// @param [out] graph: owner_graph of if node / while_cond subgraph
  99. /// @param [out] peer_out_anchor: peer_cond_anchor
  100. /// @param [out] cond_in_anchor: cond_input
  101. /// @return Status
  102. ///
  103. Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor,
  104. InDataAnchorPtr &cond_in_anchor) {
  105. GE_CHECK_NOTNULL(node);
  106. std::string type = node->GetType();
  107. if (kIfOpTypes.count(type) != 0) {
  108. if (GetCondInfoForIf(node, graph, peer_out_anchor, cond_in_anchor) != SUCCESS) {
  109. GELOGE(FAILED, "Get cond_info for if node failed.");
  110. return FAILED;
  111. }
  112. } else if (kWhileOpTypes.count(type) != 0) {
  113. if (GetCondInfoForWhile(node, graph, peer_out_anchor, cond_in_anchor) != SUCCESS) {
  114. GELOGE(FAILED, "Get cond_info for while node failed.");
  115. return FAILED;
  116. }
  117. } else {
  118. GELOGD("no need cond_pass for node %s.", node->GetName().c_str());
  119. return NOT_CHANGED;
  120. }
  121. return SUCCESS;
  122. }
  123. ///
  124. /// @brief Get cond info for if node
  125. /// @param [in] node: If op
  126. /// @param [out] graph: owner_graph of if node
  127. /// @param [out] peer_out_anchor: peer_cond_anchor
  128. /// @param [out] cond_in_anchor: cond_input of if
  129. /// @return Status
  130. ///
  131. Status CondPass::GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor,
  132. InDataAnchorPtr &cond_in_anchor) {
  133. GE_CHECK_NOTNULL(node);
  134. graph = node->GetOwnerComputeGraph();
  135. GE_CHECK_NOTNULL(graph);
  136. cond_in_anchor = node->GetInDataAnchor(IF_COND_INPUT);
  137. GE_CHECK_NOTNULL(cond_in_anchor);
  138. peer_out_anchor = cond_in_anchor->GetPeerOutAnchor();
  139. GE_CHECK_NOTNULL(peer_out_anchor);
  140. return SUCCESS;
  141. }
  142. ///
  143. /// @brief Get cond info for while node
  144. /// @param [in] node: While op
  145. /// @param [out] graph: while_cond subgraph
  146. /// @param [out] peer_out_anchor: peer_cond_anchor
  147. /// @param [out] cond_in_anchor: input of NetOutput in cond_graph
  148. /// @return Status
  149. ///
  150. Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor,
  151. InDataAnchorPtr &cond_in_anchor) {
  152. GE_CHECK_NOTNULL(node);
  153. OpDescPtr op_desc = node->GetOpDesc();
  154. GE_CHECK_NOTNULL(op_desc);
  155. std::map<std::string, uint32_t> subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  156. auto iter = subgraph_names_to_index.find(ATTR_NAME_WHILE_COND);
  157. if (iter == subgraph_names_to_index.end()) {
  158. REPORT_INNER_ERROR("E19999", "subgraph name:%s not exist in SubgraphNameIndexes map of op:%s(%s), "
  159. "check invalid when CondPass %s", ATTR_NAME_WHILE_COND.c_str(),
  160. op_desc->GetName().c_str(), op_desc->GetType().c_str(), __FUNCTION__);
  161. GELOGE(FAILED, "Get cond_graph index failed, while_node:%s.", node->GetName().c_str());
  162. return FAILED;
  163. }
  164. std::string cond_graph_instance_name = op_desc->GetSubgraphInstanceName(iter->second);
  165. graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph())->GetSubgraph(cond_graph_instance_name);
  166. GE_CHECK_NOTNULL(graph);
  167. NodePtr net_output_node = graph->FindFirstNodeMatchType(NETOUTPUT);
  168. GE_CHECK_NOTNULL(net_output_node);
  169. // cond_graph has and only has one output
  170. uint32_t output_num = net_output_node->GetAllInDataAnchorsSize();
  171. if (output_num != 1) {
  172. REPORT_INNER_ERROR("E19999", "Input data anchor num:%u of op:%s(%s) not equal to 1, check invalid when CondPass %s",
  173. output_num, op_desc->GetName().c_str(), op_desc->GetType().c_str(), __FUNCTION__);
  174. GELOGE(FAILED, "output size of cond_graph is invalid, expect 1 but %u exactly, while_node:%s.",
  175. output_num, node->GetName().c_str());
  176. return FAILED;
  177. }
  178. cond_in_anchor = net_output_node->GetInDataAnchor(0);
  179. GE_CHECK_NOTNULL(cond_in_anchor);
  180. peer_out_anchor = cond_in_anchor->GetPeerOutAnchor();
  181. GE_CHECK_NOTNULL(peer_out_anchor);
  182. return SUCCESS;
  183. }
  184. ///
  185. /// @brief Process Cond Op with non-scalar cond_input: cond->Size->If / NetOutput(while)
  186. /// @param [in] graph
  187. /// @param [in] peer_out_anchor: peer_cond_anchor
  188. /// @param [in] cond_in_anchor: cond_input
  189. /// @return Status
  190. ///
  191. Status CondPass::HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor,
  192. const InDataAnchorPtr &cond_in_anchor) {
  193. GELOGI("Handle cond with non-scalar cond-input.");
  194. return InsertNode(graph, peer_out_anchor, cond_in_anchor, SIZE);
  195. }
  196. ///
  197. /// @brief Process Cond Op with scalar-string cond_input: cond->StringLength(int32)->If / NetOutput(while)
  198. /// @param [in] graph
  199. /// @param [in] peer_out_anchor: peer_cond_anchor
  200. /// @param [in] cond_in_anchor: cond_input
  201. /// @return Status
  202. ///
  203. Status CondPass::HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor,
  204. const InDataAnchorPtr &cond_in_anchor) {
  205. GELOGI("Handle cond with scalar-string cond-input.");
  206. return InsertNode(graph, peer_out_anchor, cond_in_anchor, kStringLength);
  207. }
  208. ///
  209. /// @brief Process Cond Op with scalar cond_input: cond->Cast(2int32)->If / NetOutput(while)
  210. /// @param [in] graph
  211. /// @param [in] peer_out_anchor: peer_cond_anchor
  212. /// @param [in] cond_in_anchor: cond_input
  213. /// @param [in] src_type
  214. /// @return Status
  215. ///
  216. Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor,
  217. const InDataAnchorPtr &cond_in_anchor, DataType src_type) {
  218. GE_CHECK_NOTNULL(cond_in_anchor);
  219. GE_CHECK_NOTNULL(peer_out_anchor);
  220. GE_CHECK_NOTNULL(peer_out_anchor->GetOwnerNode()->GetOpDesc());
  221. GELOGI("Handle cond with scalar cond-input.");
  222. GeTensorDesc tensor = peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(peer_out_anchor->GetIdx());
  223. std::string cast_name = cond_in_anchor->GetOwnerNode()->GetName() + "_Cast";
  224. NodePtr cast_node = AddCastNode(graph, cast_name, tensor, src_type, DT_INT32);
  225. if (cast_node == nullptr) {
  226. GELOGE(FAILED, "Add Cast node failed, name:%s.", cast_name.c_str());
  227. return FAILED;
  228. }
  229. if (GraphUtils::InsertNodeAfter(peer_out_anchor, { cond_in_anchor }, cast_node) != GRAPH_SUCCESS) {
  230. REPORT_CALL_ERROR("E19999", "Insert Cast node %s(%s) between %s(%s)->%s(%s) failed, when CondPass %s",
  231. cast_node->GetName().c_str(), cast_node->GetType().c_str(),
  232. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  233. peer_out_anchor->GetOwnerNode()->GetType().c_str(),
  234. cond_in_anchor->GetOwnerNode()->GetName().c_str(),
  235. cond_in_anchor->GetOwnerNode()->GetType().c_str(), __FUNCTION__);
  236. GELOGE(FAILED, "Insert Cast node %s between %s->%s failed.",
  237. cast_node->GetName().c_str(), peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  238. cond_in_anchor->GetOwnerNode()->GetName().c_str());
  239. return FAILED;
  240. }
  241. return SUCCESS;
  242. }
  243. ///
  244. /// @brief Insert node
  245. /// @param [in] graph
  246. /// @param [in] peer_out_anchor
  247. /// @param [in] in_data_anchor
  248. /// @param [in] type
  249. /// @return Status
  250. ///
  251. Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor,
  252. const InDataAnchorPtr &in_data_anchor, const std::string &type) {
  253. GE_CHECK_NOTNULL(peer_out_anchor);
  254. GE_CHECK_NOTNULL(in_data_anchor);
  255. GELOGD("Begin to insert %s node.", type.c_str());
  256. GE_CHECK_NOTNULL(peer_out_anchor->GetOwnerNode()->GetOpDesc());
  257. GE_CHECK_NOTNULL(in_data_anchor->GetOwnerNode()->GetOpDesc());
  258. GeTensorDesc in_tensor = peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(peer_out_anchor->GetIdx());
  259. GeTensorDesc out_tensor = in_data_anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx());
  260. out_tensor.SetDataType(DT_INT32);
  261. out_tensor.SetOriginDataType(DT_INT32);
  262. out_tensor.SetShape(in_tensor.GetShape());
  263. out_tensor.SetOriginShape(in_tensor.GetOriginShape());
  264. OpDescBuilder op_desc_builder(in_data_anchor->GetOwnerNode()->GetName() + "_" + type, type);
  265. OpDescPtr op_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build();
  266. if (op_desc == nullptr) {
  267. REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed, when CondPass %s",
  268. (in_data_anchor->GetOwnerNode()->GetName() + "_" + type).c_str(), type.c_str(), __FUNCTION__);
  269. GELOGE(FAILED, "Create op_desc failed.");
  270. return FAILED;
  271. }
  272. NodePtr new_node = graph->AddNode(op_desc);
  273. if (new_node == nullptr) {
  274. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed when CondPass %s",
  275. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str(), __FUNCTION__);
  276. GELOGE(FAILED, "Create %s node failed.", type.c_str());
  277. return FAILED;
  278. }
  279. AddRePassNode(new_node);
  280. if (GraphUtils::InsertNodeAfter(peer_out_anchor, { in_data_anchor }, new_node) != GRAPH_SUCCESS) {
  281. REPORT_CALL_ERROR("E19999", "Insert node %s(%s) between %s(%s)->%s(%s) failed, when CondPass %s",
  282. new_node->GetName().c_str(), new_node->GetType().c_str(),
  283. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  284. peer_out_anchor->GetOwnerNode()->GetType().c_str(),
  285. in_data_anchor->GetOwnerNode()->GetName().c_str(),
  286. in_data_anchor->GetOwnerNode()->GetType().c_str(), __FUNCTION__);
  287. GELOGE(FAILED, "Insert %s node %s between %s->%s failed.", type.c_str(),
  288. new_node->GetName().c_str(), peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  289. in_data_anchor->GetOwnerNode()->GetName().c_str());
  290. return FAILED;
  291. }
  292. return SUCCESS;
  293. }
  294. ///
  295. /// @brief Add cast node
  296. /// @param [in] graph
  297. /// @param [in] name
  298. /// @param [in] tensor
  299. /// @param [in] src
  300. /// @param [in] dst
  301. /// @return NodePtr
  302. ///
  303. NodePtr CondPass::AddCastNode(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &tensor,
  304. DataType src, DataType dst) {
  305. GELOGI("Begin to create cast op: %s, from %d to %d", name.c_str(), src, dst);
  306. GeTensorDesc in_tensor = tensor;
  307. in_tensor.SetDataType(src);
  308. in_tensor.SetOriginDataType(src);
  309. GeTensorDesc out_tensor = tensor;
  310. out_tensor.SetDataType(dst);
  311. out_tensor.SetOriginDataType(dst);
  312. OpDescBuilder op_desc_builder(name, CAST);
  313. OpDescPtr cast_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build();
  314. if (cast_desc == nullptr) {
  315. REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed, when CondPass %s",
  316. name.c_str(), CAST, __FUNCTION__);
  317. GELOGE(FAILED, "Create cast op_desc failed, name: %s.", name.c_str());
  318. return nullptr;
  319. }
  320. if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, src) &&
  321. AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, dst) &&
  322. AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, dst) &&
  323. AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) {
  324. REPORT_CALL_ERROR("E19999", "Set Attr:%s,%s,%s,%s to node:%s(%s) not all success, when CondPass %s",
  325. CAST_ATTR_SRCT.c_str(), CAST_ATTR_DSTT.c_str(),
  326. CAST_ATTR_DST_TYPE.c_str(), CAST_ATTR_TRUNCATE.c_str(),
  327. cast_desc->GetName().c_str(), cast_desc->GetType().c_str(), __FUNCTION__);
  328. GELOGE(FAILED, "Set CAST_ATTR failed, node: %s.", name.c_str());
  329. return nullptr;
  330. }
  331. NodePtr cast_node = graph->AddNode(cast_desc);
  332. if (cast_node == nullptr) {
  333. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed when CondPass %s",
  334. cast_desc->GetName().c_str(), cast_desc->GetType().c_str(), graph->GetName().c_str(),
  335. __FUNCTION__);
  336. GELOGE(FAILED, "Add cast node failed, name: %s.", name.c_str());
  337. return nullptr;
  338. }
  339. AddRePassNode(cast_node);
  340. return cast_node;
  341. }
  342. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: