
infer_base_pass.cc 17 kB

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "infer_base_pass.h"
#include "common/ge/ge_util.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
namespace ge {
namespace {
graphStatus FindValidSubgraphNetoutput(const ConstNodePtr &node, const ComputeGraphPtr &sub_graph, NodePtr &netoutput) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    // Check for null before dereferencing; the null check used to come after the GetType() call.
    if (sub_node == nullptr) {
      REPORT_INNER_ERROR("E19999", "Null node in subgraph %s, parent node %s.",
                         sub_graph->GetName().c_str(), node->GetName().c_str());
      GELOGE(GRAPH_FAILED, "[Check][Param] Null node on sub graph %s, parent node %s",
             sub_graph->GetName().c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    if (sub_node->GetType() == NETOUTPUT) {
      auto sub_node_opdesc = sub_node->GetOpDesc();
      if (sub_node_opdesc == nullptr) {
        REPORT_INNER_ERROR("E19999", "Invalid NetOutput node in subgraph %s, parent node %s, no OpDesc on it",
                           sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
               sub_graph->GetName().c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      netoutput = sub_node;
      return GRAPH_SUCCESS;
    }
  }
  REPORT_INNER_ERROR("E19999", "Can not find the NetOutput node in subgraph %s, parent node %s",
                     sub_graph->GetName().c_str(), node->GetName().c_str());
  GELOGE(GRAPH_FAILED, "[Check][Param] Can not find the NetOutput node in subgraph %s, parent node %s",
         sub_graph->GetName().c_str(), node->GetName().c_str());
  return GRAPH_FAILED;
}
}  // namespace
Status InferBasePass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  bool need_infer = NeedInfer(node);
  if (!need_infer) {
    GELOGD("Node %s does not need to infer.", node->GetName().c_str());
    return SUCCESS;
  }
  std::set<NodePtr> changed_nodes;
  auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(ret, "Infer and update for node %s failed! ret: %u", node->GetName().c_str(), ret);
    return GRAPH_FAILED;
  }
  AddChangedNodesImmediateRepass(changed_nodes);
  return SUCCESS;
}
bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; }
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) {
  for (const auto &node_ele : changed_nodes) {
    AddImmediateRePassNode(node_ele);
  }
}
graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {
  graphStatus ret;
  if (ContainsSubgraph(node)) {
    if (before_subgraph) {
      ret = UpdateTensorDescToSubgraphData(node);
    } else {
      ret = UpdateTensorDescToParentNodeOutput(node);
    }
    if (ret != GRAPH_SUCCESS) {
      GELOGE(ret, "Update tensor desc failed between parent node %s and subgraphs. ret: %u", node->GetName().c_str(),
             ret);
      return ret;
    }
  }
  PrintInOutTensors(node, "before_infer");
  ret = Infer(node);
  PrintInOutTensors(node, "after_infer");
  if (ret == GRAPH_NODE_NEED_REPASS) {
    // If a node needs a re-pass, there is no need to update the peer nodes' inputs here.
    changed_nodes.insert(node);
    return GRAPH_SUCCESS;
  } else if (ret != GRAPH_SUCCESS && ret != GRAPH_NOT_CHANGED) {
    GELOGE(ret, "Infer failed for node %s, ret: %u", node->GetName().c_str(), ret);
    return ret;
  }
  ret = UpdateTensorDescToPeerInputs(node, changed_nodes);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(ret, "Node %s failed to update tensor desc to peer input nodes! ret: %u", node->GetName().c_str(), ret);
  }
  GELOGD("Node %s infer and update succeeded.", node->GetName().c_str());
  return ret;
}
bool InferBasePass::ContainsSubgraph(const NodePtr &node) {
  auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames();
  return !sub_graph_names.empty();
}
graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  auto op_desc = node->GetOpDesc();
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
      auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc();
      if (peer_anchor_opdesc == nullptr) {
        continue;
      }
      auto peer_input_desc = peer_anchor_opdesc->MutableInputDesc(peer_anchor->GetIdx());
      if (peer_input_desc == nullptr) {
        continue;
      }
      bool changed = false;
      auto ret = UpdateTensorDesc(output_tensor, peer_input_desc, changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Update peer input desc failed, node %s.", node->GetName().c_str());
        GELOGE(ret, "Update peer input desc failed, node %s.", node->GetName().c_str());
        return ret;
      }
      if (changed) {
        changed_nodes.insert(peer_anchor->GetOwnerNode());
        GELOGD("Node %s update peer node succeeded, peer node %s is changed.", node->GetName().c_str(),
               peer_anchor->GetOwnerNode()->GetName().c_str());
      }
    }
  }
  return GRAPH_SUCCESS;
}
std::vector<ComputeGraphPtr> InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) {
  std::vector<ComputeGraphPtr> cur_node_subgraph;
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return cur_node_subgraph;
  }
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains an empty subgraph instance name", node->GetName().c_str());
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GELOGW("The subgraph %s for node %s is null.", name.c_str(), node->GetName().c_str());
      continue;
    }
    cur_node_subgraph.emplace_back(sub_graph);
  }
  return cur_node_subgraph;
}
graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      if (node_sub->GetType() != DATA) {
        continue;
      }
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc",
                           sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc",
               sub_graph->GetName().c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      int ref_i;
      if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
                           sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
               sub_graph->GetName().c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      GELOGD("Subgraph Data node ref_index is %d, parent node is %s.", ref_i, node->GetName().c_str());
      // In multi-batch scenarios the data shapes of the subgraphs differ, so there is no need to refresh them.
      if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
        GELOGD("While updating subgraph data node, ignore node %s which is created by multi-dims",
               data_opdesc->GetName().c_str());
        continue;
      }
      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "The ref index(%d) on the data %s on the sub graph %s "
                           "parent node %s is incompatible, inputs num %u",
                           ref_i, node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str(),
                           node->GetAllInDataAnchorsSize());
        GELOGE(GRAPH_FAILED,
               "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s "
               "parent node %s is incompatible, inputs num %u",
               ref_i, node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str(),
               node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
             node->GetName().c_str());
      bool has_tensor_desc_changed = false;
      auto data_input_td = data_opdesc->MutableInputDesc(0);
      auto ret = UpdateTensorDesc(input_desc, data_input_td, has_tensor_desc_changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Update][InputDesc] of data %s on the sub graph %s parent node %s failed",
               node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str());
        return ret;
      }
      auto data_output_td = data_opdesc->MutableOutputDesc(0);
      ret = UpdateTensorDesc(input_desc, data_output_td, has_tensor_desc_changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed",
               node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str());
        return ret;
      }
      GELOGD("Parent node %s update subgraph data %s input and output succeeded.", node->GetName().c_str(),
             data_opdesc->GetName().c_str());
    }
  }
  return GRAPH_SUCCESS;
}
graphStatus InferBasePass::UpdateTensorDescToParentNodeOutput(NodePtr &node) {
  std::vector<std::vector<GeTensorDescPtr>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    NodePtr netoutput;
    auto ret = FindValidSubgraphNetoutput(node, sub_graph, netoutput);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
    auto netoutput_opdesc = netoutput->GetOpDesc();
    for (auto &netoutput_in_anchor : netoutput->GetAllInDataAnchors()) {
      auto netoutput_in_desc = netoutput_opdesc->MutableInputDesc(netoutput_in_anchor->GetIdx());
      if (netoutput_in_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
                           sub_graph->GetName().c_str(), node->GetName().c_str(), netoutput_in_anchor->GetIdx());
        GELOGE(GRAPH_FAILED,
               "[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
               sub_graph->GetName().c_str(), node->GetName().c_str(), netoutput_in_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", netoutput_in_anchor->GetIdx(),
             netoutput_in_desc->GetShape().GetDimNum());
      int ref_i;
      if (!AttrUtils::GetInt(netoutput_in_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        // If there is no ref index on the TensorDesc, the output data will be ignored by the outer graph.
        continue;
      }
      GELOGI("Parent node index of edge desc is %d", ref_i);
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
        REPORT_INNER_ERROR("E19999",
                           "Invalid ref_index %d of parent node %s, ref_index should be less than %u.", ref_i,
                           node->GetName().c_str(), node->GetAllOutDataAnchorsSize());
        GELOGE(GRAPH_FAILED,
               "[Get][Ref_index] Invalid ref_index %d of parent node %s, ref_index should be less than %u.", ref_i,
               node->GetName().c_str(), node->GetAllOutDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_out_tensors[ref_i].emplace_back(netoutput_in_desc);
    }
  }
  return UpdateParentNodeContainsSubgraphs(node, ref_out_tensors);
}
graphStatus InferBasePass::UpdateParentNodeContainsSubgraphs(
    NodePtr &node, const std::vector<std::vector<GeTensorDescPtr>> &ref_out_tensors) {
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      REPORT_CALL_ERROR("E19999", "Parent node %s ref_index %zu subgraph output tensor list is empty.",
                        node->GetName().c_str(), i);
      GELOGE(GRAPH_FAILED, "[Param][check] Parent node %s ref_index %zu subgraph output tensor list is empty.",
             node->GetName().c_str(), i);
      return GRAPH_FAILED;
    }
    auto node_op_desc = node->GetOpDesc();
    auto node_output_td = node_op_desc->MutableOutputDesc(i);
    if (node_output_td == nullptr) {
      REPORT_CALL_ERROR("E19999", "Node %s output %zu tensor desc is null.", node->GetName().c_str(), i);
      GELOGE(GRAPH_FAILED, "[Param][check] Node %s output %zu tensor desc is null.", node->GetName().c_str(), i);
      return GRAPH_FAILED;
    }
    graphStatus ret;
    if (node_op_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
      ret = UpdateOutputFromSubgraphsForMultiDims(ref_out_tensors[i], node_output_td);
    } else {
      ret = UpdateOutputFromSubgraphs(ref_out_tensors[i], node_output_td);
    }
    if (ret != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Node %s update output %zu tensor desc failed. ret: %u", node->GetName().c_str(), i,
                        ret);
      GELOGE(GRAPH_FAILED, "[Param][check] Node %s update output %zu tensor desc failed. ret: %u",
             node->GetName().c_str(), i, ret);
      return ret;
    }
    GELOGD("Parent node %s successfully updated the output tensors from subgraphs.", node->GetName().c_str());
  }
  return GRAPH_SUCCESS;
}
void InferBasePass::PrintInOutTensors(const NodePtr &node, const std::string &phase) {
  if (!IsLogEnable(GE, DLOG_DEBUG)) {
    return;
  }
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
    GELOGE(GRAPH_FAILED, "[Check][Param] node is null");
    return;
  }
  ge::OpDescPtr op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "Node has no opdesc, check invalid");
                  GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return );
  std::stringstream ss;
  ss << "{";
  int32_t in_idx = 0;
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (input_desc == nullptr) {
      in_idx++;
      continue;
    }
    if (in_idx > 0) {
      ss << " ";
    }
    ss << "input_" << in_idx << " tensor: ";
    ss << SerialTensorInfo(input_desc);
    in_idx++;
  }
  int32_t out_idx = 0;
  for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
    if (output_desc == nullptr) {
      out_idx++;
      continue;
    }
    ss << " ";
    ss << "output_" << out_idx << " tensor: ";
    ss << SerialTensorInfo(output_desc);
    out_idx++;
  }
  ss << "}";
  GELOGD("Infer tensor dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str());
}
}  // namespace ge
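The virtual hooks this file drives (NeedInfer, Infer, UpdateTensorDesc, UpdateOutputFromSubgraphs, UpdateOutputFromSubgraphsForMultiDims, SerialTensorInfo) are declared in infer_base_pass.h, which is not shown on this page. The sketch below is only an illustration of how a concrete pass might plug into InferBasePass: the class name is hypothetical and the override signatures are assumptions reconstructed from the call sites above, not the repository's actual declarations.

// Hypothetical sketch, NOT part of the repository. Signatures are guessed from the call
// sites in infer_base_pass.cc; the authoritative declarations live in infer_base_pass.h.
#include <string>
#include <vector>
#include "infer_base_pass.h"

namespace ge {
class DataTypeInferExamplePass : public InferBasePass {  // hypothetical derived pass
 protected:
  // Called by InferAndUpdate(); a real pass would run the operator's registered inference
  // here and may return GRAPH_NODE_NEED_REPASS to have the node scheduled for another pass.
  graphStatus Infer(NodePtr &node) override {
    (void)node;
    return GRAPH_SUCCESS;
  }

  // Copies the fields this pass cares about from src to dst. `changed` tells the base class
  // whether dst was modified, so the peer (or subgraph Data) node can be re-passed.
  graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override {
    if (src == nullptr || dst == nullptr) {
      return GRAPH_FAILED;
    }
    changed = (dst->GetDataType() != src->GetDataType());
    dst->SetDataType(src->GetDataType());
    return GRAPH_SUCCESS;
  }

  // Merges the descs collected from each subgraph NetOutput input into one parent-node output.
  graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override {
    bool changed = false;
    return src.empty() ? GRAPH_FAILED : UpdateTensorDesc(src.front(), dst, changed);
  }

  graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                    GeTensorDescPtr &dst) override {
    return UpdateOutputFromSubgraphs(src, dst);
  }

  // Used only by PrintInOutTensors() for debug logging.
  std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override {
    return tensor_desc == nullptr ? "null" : std::to_string(static_cast<int32_t>(tensor_desc->GetDataType()));
  }
};
}  // namespace ge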

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and acts as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts: GE API and GE Core.