You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cond_remove_pass.cc 18 kB

5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "graph/passes/cond_remove_pass.h"

  #include <cstring>

  #include "common/op/ge_op_utils.h"
  #include "graph/utils/graph_utils.h"
  #include "graph/utils/node_utils.h"
  #include "graph/utils/type_utils.h"
  21. namespace {
  22. const uint32_t kConditionIndexNum = 1;
  23. const uint32_t kElseBranchIndex = 1;
  24. const uint32_t kTrueIndex = 1;
  25. const uint32_t kFalseIndex = 0;
  26. /// Extra 8 bytes store pointer of string
  27. /// Extra 8 bytes store length of string
  28. /// Extra 1 byte store '\0'
  29. const int32_t kStrHeadLen = sizeof(ge::StringHead) + 1;
  30. const int32_t kInvalidRetVal = -1;
  31. }
  32. namespace ge {
  33. Status CondRemovePass::Run(NodePtr &node) {
  34. GE_CHECK_NOTNULL(node);
  35. ComputeGraphPtr graph = nullptr;
  36. OutDataAnchorPtr cond_out_anchor = nullptr;
  37. InDataAnchorPtr cond_in_anchor = nullptr;
  38. Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor);
  39. if (ret == NOT_CHANGED) {
  40. return SUCCESS;
  41. } else if (ret != SUCCESS) {
  42. GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str());
  43. return FAILED;
  44. }
  45. int32_t cond_index = 0;
  46. GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str());
  47. bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index);
  48. if (!if_cond_const || (cond_index < 0)) {
  49. return ge::SUCCESS;
  50. }
  51. ComputeGraphPtr chosen_graph = nullptr;
  52. const std::string &node_type = node->GetType();
  53. // Keep chosen branch
  54. if (kIfOpTypes.count(node_type) != 0) {
  55. ret = GetIfChosenBranch(node, static_cast<uint32_t>(cond_index), chosen_graph);
  56. if (ret != ge::SUCCESS) {
  57. return ge::FAILED;
  58. }
  59. } else if (kCaseOpTypes.count(node_type) != 0) {
  60. ret = GetCaseChosenBranch(node, static_cast<uint32_t>(cond_index), chosen_graph);
  61. if (ret != ge::SUCCESS) {
  62. return ge::FAILED;
  63. }
  64. } else {
  65. return ge::SUCCESS;
  66. }
  67. // Remove unused link from cond->node
  68. ret = RemoveDeadCondLink(static_cast<int32_t>(IF_COND_INPUT), node);
  69. if (ret != ge::SUCCESS) {
  70. return ge::FAILED;
  71. }
  72. // Copy If/Case node's relations to the new node
  73. ret = ReplaceIfCaseNodeWithPartitioncall(node, chosen_graph);
  74. if (ret != ge::SUCCESS) {
  75. return ge::FAILED;
  76. }
  77. // Isolate and delete the old node
  78. ret = IsolateAndDeleteNode(node, std::vector<int>());
  79. return ret;
  80. }
  81. Status CondRemovePass::RemoveDeadCondLink(const int32_t index, const NodePtr &node) {
  82. const auto &in_anchor = node->GetInDataAnchor(index);
  83. const auto &peerout_anchor = in_anchor->GetPeerOutAnchor();
  84. if (GraphUtils::RemoveEdge(peerout_anchor, in_anchor) != SUCCESS) {
  85. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(out_index:%d) and op:%s(%s)(in_index:%d) failed "
  86. "when CondRemovePass %s",
  87. peerout_anchor->GetOwnerNode()->GetName().c_str(),
  88. peerout_anchor->GetOwnerNode()->GetType().c_str(), peerout_anchor->GetIdx(),
  89. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetType().c_str(),
  90. in_anchor->GetIdx(), __FUNCTION__);
  91. GELOGE(FAILED, "Remove edge from node %s index %d to node %s index %d.",
  92. peerout_anchor->GetOwnerNode()->GetName().c_str(), peerout_anchor->GetIdx(),
  93. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  94. return FAILED;
  95. }
  96. return SUCCESS;
  97. }
  98. Status CondRemovePass::GetCaseChosenBranch(const NodePtr &node, const uint32_t cond_index,
  99. ComputeGraphPtr &compute_graph) {
  100. uint32_t subgraph_names_size = static_cast<uint32_t>(node->GetOpDesc()->GetSubgraphInstanceNames().size());
  101. uint32_t cond_index_new = cond_index;
  102. if (subgraph_names_size == 0) {
  103. REPORT_INNER_ERROR("E19999", "subgraph size of op:%s(%s) is 0, check invavlid when CondRemovePass %s",
  104. node->GetName().c_str(), node->GetType().c_str(), __FUNCTION__);
  105. GELOGE(FAILED, "Node %s has none subgraph.", node->GetName().c_str());
  106. return ge::FAILED;
  107. }
  108. // If cond index is over the maimum subgraph number, choose the last subgraph
  109. if (cond_index >= subgraph_names_size) {
  110. cond_index_new = subgraph_names_size - 1;
  111. }
  112. const auto &chosen_branch_name = node->GetOpDesc()->GetSubgraphInstanceName(cond_index_new);
  113. if (chosen_branch_name.empty()) {
  114. REPORT_INNER_ERROR("E19999", "Get subgraph name from op:%s(%s) by index:%u failed, when CondRemovePass %s",
  115. node->GetName().c_str(), node->GetType().c_str(), cond_index_new, __FUNCTION__);
  116. GELOGE(FAILED, "Node %s has no subgraph, index is %u.", node->GetName().c_str(), cond_index_new);
  117. return ge::FAILED;
  118. }
  119. auto chosen_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph())->GetSubgraph(chosen_branch_name);
  120. compute_graph = chosen_graph;
  121. // Remove graph from node, in order for remove connection from this node to chosen branch
  122. node->GetOpDesc()->RemoveSubgraphInstanceName(chosen_branch_name);
  123. return ge::SUCCESS;
  124. }
  125. Status CondRemovePass::GetIfChosenBranch(const NodePtr &node, const uint32_t cond, ComputeGraphPtr &compute_graph) {
  126. uint32_t subgraph_names_size = static_cast<uint32_t>(node->GetOpDesc()->GetSubgraphInstanceNames().size());
  127. uint32_t cond_index_new = 0;
  128. if (subgraph_names_size == 0) {
  129. REPORT_INNER_ERROR("E19999", "subgraph size of op:%s(%s) is 0, check invavlid when CondRemovePass %s",
  130. node->GetName().c_str(), node->GetType().c_str(), __FUNCTION__);
  131. GELOGE(FAILED, "Node %s has none subgraph.", node->GetName().c_str());
  132. return ge::FAILED;
  133. }
  134. // If cond is false, else branch
  135. if (cond == 0) {
  136. cond_index_new = kElseBranchIndex;
  137. }
  138. const auto &chosen_branch_name = node->GetOpDesc()->GetSubgraphInstanceName(cond_index_new);
  139. if (chosen_branch_name.empty()) {
  140. REPORT_INNER_ERROR("E19999", "Get subgraph name from op:%s(%s) by index:%u failed, when CondRemovePass %s",
  141. node->GetName().c_str(), node->GetType().c_str(), cond_index_new, __FUNCTION__);
  142. GELOGE(FAILED, "Node %s has no subgraph, index is %u.", node->GetName().c_str(), cond_index_new);
  143. return ge::FAILED;
  144. }
  145. auto chosen_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph())->GetSubgraph(chosen_branch_name);
  146. if (chosen_graph == nullptr) {
  147. REPORT_INNER_ERROR("E19999",
  148. "Find subgraph by name:%s from node:%s(%s)'s root_graph failed, when CondRemovePass %s",
  149. chosen_branch_name.c_str(), node->GetName().c_str(), node->GetType().c_str(), __FUNCTION__);
  150. GELOGE(FAILED, "Can not find branch %s in node %s's parent graph %s.", chosen_branch_name.c_str(),
  151. node->GetName().c_str(), node->GetOwnerComputeGraph()->GetName().c_str());
  152. return ge::FAILED;
  153. }
  154. compute_graph = chosen_graph;
  155. // Remove graph from node, in order for remove connection from this node to chosen branch
  156. node->GetOpDesc()->RemoveSubgraphInstanceName(chosen_branch_name);
  157. return ge::SUCCESS;
  158. }
  159. int32_t CondRemovePass::GetCondIndex(const ConstGeTensorPtr &tensor) {
  160. if (tensor == nullptr) {
  161. return kInvalidRetVal;
  162. }
  163. const uint8_t *data_ptr = tensor->GetData().data();
  164. size_t tensor_size = tensor->GetData().size();
  165. const auto type = tensor->GetTensorDesc().GetDataType();
  166. GELOGD("Data type is %d, tensor_size is %zu.", type, tensor_size);
  167. switch (type) {
  168. case DT_STRING:
  169. return static_cast<int32_t>(((tensor_size - kStrHeadLen) > 0) ? kTrueIndex : kFalseIndex);
  170. case DT_BOOL:
  171. return static_cast<int32_t>(*reinterpret_cast<const bool *>(data_ptr));
  172. case DT_FLOAT:
  173. return static_cast<int32_t>(*reinterpret_cast<const float *>(data_ptr));
  174. case DT_DOUBLE:
  175. return static_cast<int32_t>(*reinterpret_cast<const double *>(data_ptr));
  176. case DT_INT8:
  177. case DT_UINT8:
  178. return static_cast<int32_t>(*data_ptr);
  179. case DT_FLOAT16:
  180. case DT_INT16:
  181. case DT_UINT16:
  182. return static_cast<int32_t>(*reinterpret_cast<const int16_t *>(data_ptr));
  183. case DT_INT32:
  184. return static_cast<int32_t>(*reinterpret_cast<const int32_t *>(data_ptr));
  185. case DT_UINT32:
  186. return *reinterpret_cast<const int32_t *>(data_ptr);
  187. case DT_INT64:
  188. case DT_UINT64:
  189. return static_cast<int32_t>(*reinterpret_cast<const int64_t *>(data_ptr));
  190. default:
  191. return static_cast<int32_t>(*data_ptr);
  192. }
  193. }
  194. bool CondRemovePass::CheckIfCondConstInput(const OutDataAnchorPtr &cond_out_anchor,
  195. const InDataAnchorPtr &cond_in_anchor, int32_t &cond_index) {
  196. // if pre or next anchor is null, return
  197. CHECK_FALSE_EXEC(cond_out_anchor != nullptr, return false);
  198. CHECK_FALSE_EXEC(cond_in_anchor != nullptr, return false);
  199. const auto &out_node = cond_out_anchor->GetOwnerNode();
  200. const auto &cur_node = cond_in_anchor->GetOwnerNode();
  201. OpDescPtr op_desc = cur_node->GetOpDesc();
  202. GE_CHECK_NOTNULL_EXEC(op_desc, return false);
  203. GeTensorDesc cond_tensor = out_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(cond_out_anchor->GetIdx()));
  204. GELOGI("Check if condition is const for node %s.", op_desc->GetName().c_str());
  205. if (kConstOpTypes.count(out_node->GetOpDesc()->GetType()) == 0) {
  206. return false;
  207. }
  208. // Case node only support int32 input
  209. if ((kCaseOpTypes.count(cur_node->GetType()) != 0) && (cond_tensor.GetDataType() != DT_INT32)) {
  210. GELOGW("Check input failed, node is %s, condition datatype is %s.", op_desc->GetName().c_str(),
  211. TypeUtils::DataTypeToSerialString(cond_tensor.GetDataType()).c_str());
  212. return false;
  213. }
  214. // Get weights from peer node
  215. auto weights = OpDescUtils::GetWeights(out_node);
  216. if (weights.size() <= static_cast<size_t>(cond_out_anchor->GetIdx())) {
  217. GELOGI("Get weights of node %s out index %d, weight size %zu is not fit for data index %d.",
  218. out_node->GetName().c_str(), cond_out_anchor->GetIdx(), weights.size(), cond_out_anchor->GetIdx());
  219. return false;
  220. }
  221. ConstGeTensorPtr tensor = weights[cond_out_anchor->GetIdx()];
  222. GE_CHECK_NOTNULL_EXEC(tensor, return false);
  223. bool if_zero_dim = false;
  224. if (!cond_tensor.GetShape().IsScalar()) {
  225. for (size_t dim = 0; dim < cond_tensor.GetShape().GetDimNum(); dim++) {
  226. if (cond_tensor.GetShape().GetDim(dim) == 0) {
  227. if_zero_dim = true;
  228. break;
  229. }
  230. }
  231. // If dim num is not zero and do not has zero dim, index is 1, else index is 0
  232. cond_index = static_cast<int32_t>((cond_tensor.GetShape().GetDimNum() != 0) && !if_zero_dim);
  233. } else {
  234. // Get condition index
  235. cond_index = GetCondIndex(tensor);
  236. }
  237. GELOGD("Condition index is %d, node name is %s, anchor index is %d, dim num is %zu, zero dim flag %d", cond_index,
  238. op_desc->GetName().c_str(), cond_out_anchor->GetIdx(), cond_tensor.GetShape().GetDimNum(), if_zero_dim);
  239. return true;
  240. }
  241. Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch) {
  242. // Add compute graph to new node
  243. const auto &input_desc_size = node->GetOpDesc()->GetInputsSize();
  244. const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize();
  245. // Create subgraph opdesc & node
  246. auto partitioncall_opdesc =
  247. CreateSubgraphOpDesc(node, save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size);
  248. auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc);
  249. // Link node's peerout anchors to new node's inanchors
  250. for (const auto &input_anchor : node->GetAllInAnchors()) {
  251. for (const auto &peerout_anchor : input_anchor->GetPeerAnchors()) {
  252. if (GraphUtils::AddEdge(peerout_anchor, partitioncall_node->GetInAnchor(
  253. input_anchor->GetIdx() - kConditionIndexNum)) != ge::GRAPH_SUCCESS) {
  254. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(out_index:%d) and op:%s(%s)(in_index:%d) failed "
  255. "when CondRemovePass %s",
  256. peerout_anchor->GetOwnerNode()->GetName().c_str(),
  257. peerout_anchor->GetOwnerNode()->GetType().c_str(), peerout_anchor->GetIdx(),
  258. partitioncall_node->GetName().c_str(),
  259. partitioncall_node->GetType().c_str(), input_anchor->GetIdx(), __FUNCTION__);
  260. GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%zu, output num:%zu",
  261. peerout_anchor->GetOwnerNode()->GetName().c_str(), peerout_anchor->GetIdx(),
  262. partitioncall_node->GetName().c_str(), input_anchor->GetIdx(), input_desc_size,
  263. output_desc_size);
  264. return FAILED;
  265. }
  266. }
  267. }
  268. // Remove If / Case anchor and peer in anchor
  269. // Link new node's out anchors to node's peer inanchors
  270. for (const auto &output_anchor : node->GetAllOutAnchors()) {
  271. for (const auto &peerin_anchor : output_anchor->GetPeerAnchors()) {
  272. if (GraphUtils::RemoveEdge(node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) != ge::GRAPH_SUCCESS) {
  273. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(out_index:%d) and op:%s(%s)(in_index:%d) failed "
  274. "when CondRemovePass %s",
  275. node->GetName().c_str(), node->GetType().c_str(), output_anchor->GetIdx(),
  276. peerin_anchor->GetOwnerNode()->GetName().c_str(),
  277. peerin_anchor->GetOwnerNode()->GetType().c_str(), peerin_anchor->GetIdx(), __FUNCTION__);
  278. GELOGE(FAILED, "Remove edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%zu, output num:%zu",
  279. node->GetName().c_str(), output_anchor->GetIdx(), peerin_anchor->GetOwnerNode()->GetName().c_str(),
  280. peerin_anchor->GetIdx(), input_desc_size, output_desc_size);
  281. return FAILED;
  282. }
  283. if (GraphUtils::AddEdge(partitioncall_node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) !=
  284. ge::GRAPH_SUCCESS) {
  285. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(out_index:%d) and op:%s(%s)(in_index:%d) failed "
  286. "when CondRemovePass %s",
  287. partitioncall_node->GetName().c_str(),
  288. partitioncall_node->GetType().c_str(), output_anchor->GetIdx(),
  289. peerin_anchor->GetOwnerNode()->GetName().c_str(),
  290. peerin_anchor->GetOwnerNode()->GetType().c_str(), peerin_anchor->GetIdx(), __FUNCTION__);
  291. GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%zu, output num:%zu",
  292. partitioncall_node->GetName().c_str(), output_anchor->GetIdx(),
  293. peerin_anchor->GetOwnerNode()->GetName().c_str(), peerin_anchor->GetIdx(), input_desc_size,
  294. output_desc_size);
  295. return FAILED;
  296. }
  297. }
  298. }
  299. // update save branch information
  300. std::map<uint32_t, uint32_t> input_mapping;
  301. uint32_t new_input_num = static_cast<uint32_t>(node->GetOpDesc()->GetAllInputsSize()) - kConditionIndexNum;
  302. for (uint32_t i = 0; i < new_input_num; i++) {
  303. // original index + 1 map to index
  304. input_mapping[i + 1] = i;
  305. }
  306. save_branch->UpdateInputMapping(input_mapping);
  307. save_branch->SetParentNode(partitioncall_node);
  308. save_branch->SetParentGraph(node->GetOwnerComputeGraph());
  309. return SUCCESS;
  310. }
  311. ///
  312. /// @brief Create op_desc for subgraph node
  313. /// @param [in] name
  314. /// @param [in] input_num
  315. /// @param [in] output_num
  316. /// @return OpDescPtr
  317. ///
  318. OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num,
  319. size_t output_num) {
  320. OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
  321. op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num);
  322. OpDescPtr op_desc = op_desc_builder.Build();
  323. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  324. size_t index = op_desc->GetSubgraphInstanceNames().size();
  325. op_desc->AddSubgraphName("f");
  326. op_desc->SetSubgraphInstanceName(static_cast<uint32_t>(index), name);
  327. auto node_desc = node->GetOpDesc();
  328. GE_CHECK_NOTNULL_EXEC(node_desc, return nullptr);
  329. for (size_t i = 0; i < input_num; ++i) {
  330. (void)op_desc->UpdateInputDesc(i, node_desc->GetInputDesc(i + 1));
  331. }
  332. for (size_t i = 0; i < output_num; ++i) {
  333. (void)op_desc->UpdateOutputDesc(i, node_desc->GetOutputDesc(i));
  334. }
  335. return op_desc;
  336. }
  337. ///
  338. /// @brief Get cond info for if/case node
  339. /// @param [in] node: If/Case op
  340. /// @param [out] graph: owner_graph of if node
  341. /// @param [out] cond_out_anchor: peer_cond_anchor
  342. /// @param [out] cond_in_anchor: cond_input of if
  343. /// @return Status
  344. ///
  345. Status CondRemovePass::GetCondInfoForIfCase(const NodePtr &node, ComputeGraphPtr &graph,
  346. OutDataAnchorPtr &cond_out_anchor, InDataAnchorPtr &cond_in_anchor) {
  347. GE_CHECK_NOTNULL(node);
  348. graph = node->GetOwnerComputeGraph();
  349. GE_CHECK_NOTNULL(graph);
  350. cond_in_anchor = node->GetInDataAnchor(IF_COND_INPUT);
  351. GE_CHECK_NOTNULL(cond_in_anchor);
  352. cond_out_anchor = cond_in_anchor->GetPeerOutAnchor();
  353. GE_CHECK_NOTNULL(cond_out_anchor);
  354. return SUCCESS;
  355. }
  356. Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor,
  357. InDataAnchorPtr &cond_in_anchor) {
  358. GE_CHECK_NOTNULL(node);
  359. std::string type = node->GetType();
  360. if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) {
  361. if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) {
  362. GELOGE(FAILED, "Get cond_info for if/case node failed.");
  363. return FAILED;
  364. }
  365. } else {
  366. GELOGD("no need cond_remove_pass for node %s.", node->GetName().c_str());
  367. return NOT_CHANGED;
  368. }
  369. return SUCCESS;
  370. }
  371. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：