You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

infer_base_pass.cc 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "infer_base_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "common/util/error_manager/error_manager.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "framework/common/util.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/debug/ge_util.h"
  23. #include "graph/utils/graph_utils.h"
  24. #include "graph/utils/node_utils.h"
  25. #include "graph/utils/tensor_utils.h"
  26. #include "graph/utils/type_utils.h"
  27. namespace ge {
  28. namespace {
  29. string Serial(const vector<int64_t> &dims) {
  30. string serial_string;
  31. serial_string += "[";
  32. for (int64_t dim : dims) {
  33. serial_string += std::to_string(dim) + " ";
  34. }
  35. serial_string += "]";
  36. return serial_string;
  37. }
  38. void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  39. desc_str += "[";
  40. std::vector<std::pair<int64_t, int64_t>> shape_range;
  41. (void)desc->GetShapeRange(shape_range);
  42. for (const auto &pair : shape_range) {
  43. desc_str += "{";
  44. desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
  45. desc_str += "},";
  46. }
  47. desc_str += "]";
  48. shape_range.clear();
  49. (void)desc->GetOriginShapeRange(shape_range);
  50. for (const auto &pair : shape_range) {
  51. desc_str += ",{";
  52. desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
  53. desc_str += "},";
  54. }
  55. }
  56. graphStatus FindSubgraphDataAndNetoutput(const ComputeGraphPtr &sub_graph, NodePtr &netoutput, const ConstNodePtr &node,
  57. std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
  58. auto sub_nodes = sub_graph->GetDirectNode();
  59. for (size_t i = sub_nodes.size(); i > 0; --i) {
  60. auto sub_node = sub_nodes.at(i - 1);
  61. if (sub_node->GetType() == NETOUTPUT) {
  62. netoutput = sub_node;
  63. }
  64. if (sub_node->GetType() == DATA) {
  65. if (sub_node->GetOpDesc() == nullptr) {
  66. return GRAPH_FAILED;
  67. }
  68. int ref_i;
  69. if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
  70. REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
  71. GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
  72. return GRAPH_FAILED;
  73. }
  74. if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
  75. REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!",
  76. sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
  77. GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!",
  78. sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
  79. return GRAPH_FAILED;
  80. }
  81. ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
  82. }
  83. }
  84. return GRAPH_SUCCESS;
  85. }
  86. } // namespace
  87. Status InferBasePass::Run(NodePtr &node) {
  88. GE_CHECK_NOTNULL(node);
  89. GE_CHECK_NOTNULL(node->GetOpDesc());
  90. bool need_infer = NeedInfer(node);
  91. if (!need_infer) {
  92. GELOGD("Node %s does not need to infer.", node->GetName().c_str());
  93. return SUCCESS;
  94. }
  95. std::set<NodePtr> changed_nodes;
  96. auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes);
  97. if (ret != GRAPH_SUCCESS) {
  98. (void)AnalyzeFailedInfo(node);
  99. return GE_GRAPH_INFERSHAPE_FAILED;
  100. }
  101. /*
  102. * we will use changed nodes to do repass for control_ops.
  103. * AddChangedNodesImmediateRepass(changed_nodes);
  104. */
  105. auto status = DoRepassForLoopNode(node);
  106. if (status != SUCCESS) {
  107. GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "repass failed. node: %s", node->GetName().c_str());
  108. return GE_GRAPH_INFERSHAPE_FAILED;
  109. }
  110. return SUCCESS;
  111. }
// Hook: whether Run() should infer this node; base implementation always says yes.
bool InferBasePass::NeedInfer(const NodePtr &node) { return true; }
// Hook: invoked by Run() when InferAndUpdate fails; base implementation does nothing.
void InferBasePass::AnalyzeFailedInfo(const NodePtr &node) { /* Analyze and select failed info*/ }
// Hook: re-pass handling for loop-class nodes after inference; base implementation is a no-op.
Status InferBasePass::DoRepassForLoopNode(NodePtr &node) { return SUCCESS; }
// Hook: propagate this node's inferred output descs to peer input descs; base implementation is a no-op.
graphStatus InferBasePass::UpdatePeerInputs(NodePtr &node) { return GRAPH_SUCCESS; }
  116. void InferBasePass::AddChangedNodesImmediateRepass(std::set<NodePtr> &changed_nodes) {
  117. for (const auto &node_ele : changed_nodes) {
  118. AddImmediateRePassNode(node_ele);
  119. }
  120. }
// Orchestrates one inference round for `node`:
//   1) refresh the node's input descs from its peers (skipped for unknown-shape
//      graphs and for ops carrying the "has_infered_verified" attr),
//   2) if the node owns subgraphs and we are before subgraph optimization,
//      push parent tensor descs into subgraph Data nodes,
//   3) run the subclass Infer(),
//   4) if after subgraph optimization, pull subgraph results back to the parent,
//   5) let the UpdatePeerInputs hook propagate to downstream inputs.
// Nodes whose descs changed are accumulated in `changed_nodes`.
graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {
  auto ret = GRAPH_SUCCESS;
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  auto opdesc = node->GetOpDesc();
  // some op can not infershape twice such as aipp
  bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified");
  if (need_update_input) {
    ret = UpdateCurOpInputDesc(node);
    if (ret != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "update op input_desc failed! ret:%d, node:%s", ret, node->GetName().c_str());
      GELOGE(GRAPH_FAILED, "[Update][OpInputDesc] failed! ret:%d", ret);
      return ret;
    }
  }
  bool contain_subgraph = ContainsSubgraph(node);
  if (contain_subgraph && before_subgraph) {
    // update subgraph Data nodes from the parent before inferring the body
    ret = UpdateTensorDescToSubgraphData(node, changed_nodes);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  ret = Infer(node);
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }
  if (contain_subgraph && !before_subgraph) {
    // pull the subgraph NetOutput descs back up into the parent node
    ret = UpdateTensorDescToParentNode(node, changed_nodes);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  ret = UpdatePeerInputs(node);
  return ret;
}
  155. graphStatus InferBasePass::UpdateCurOpInputDesc(const NodePtr &node_ptr) {
  156. for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) {
  157. auto in_idx = in_anchor->GetIdx();
  158. auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
  159. if (peer_out_data_anchor == nullptr) {
  160. continue;
  161. }
  162. auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
  163. if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
  164. continue;
  165. }
  166. int peer_out_idx = peer_out_data_anchor->GetIdx();
  167. auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));
  168. // check shape and dtype continuity. do not stop process
  169. auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx));
  170. if (in_desc == nullptr) {
  171. continue;
  172. }
  173. auto in_shape = in_desc->MutableShape().GetDims();
  174. auto in_dtype = in_desc->GetDataType();
  175. auto peer_out_shape = peer_out_desc->MutableShape().GetDims();
  176. auto peer_out_dtype = peer_out_desc->GetDataType();
  177. if (peer_out_dtype != in_dtype) {
  178. GELOGW(
  179. "current node [%s] [%d]\'th in_dtype is [%s].peer output node [%s] [%d]\'th "
  180. "output_dtype is [%s].The two dtype should be same! Please check graph and fix it",
  181. node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(),
  182. peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str());
  183. } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) {
  184. string in_shape_str = " "; // Serial(in_shape);
  185. string peer_out_shape_str = " "; // Serial(peer_out_shape);
  186. GELOGW(
  187. "current node [%s] [%d]\'th in_shape is [%s].peer output node [%s] [%d]\'th "
  188. "output_shape is [%s].The two shape should be same! Please check graph and fix it",
  189. node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx,
  190. peer_out_shape_str.c_str());
  191. }
  192. // refresh current node input desc
  193. bool output_changed = false;
  194. (void)UpdateInputDescAttr(peer_out_desc, in_desc, output_changed);
  195. }
  196. return GRAPH_SUCCESS;
  197. }
// Hook: copy attributes from `src` (peer output desc) into `dst` (current
// input desc). The base implementation copies nothing and reports no change.
graphStatus InferBasePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  changed = false;
  return GRAPH_SUCCESS;
}
  202. bool InferBasePass::ContainsSubgraph(const NodePtr &node) {
  203. auto op_desc = node->GetOpDesc();
  204. auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  205. if (sub_graph_names.empty()) {
  206. return false;
  207. }
  208. auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  209. if (root_graph == nullptr) {
  210. return false;
  211. }
  212. for (const auto &name : sub_graph_names) {
  213. if (name.empty()) {
  214. continue;
  215. }
  216. auto sub_graph = root_graph->GetSubgraph(name);
  217. if (sub_graph != nullptr) {
  218. return true;
  219. }
  220. }
  221. return false;
  222. }
  223. std::vector<ComputeGraphPtr> InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) {
  224. std::vector<ComputeGraphPtr> cur_node_subgraph;
  225. auto op_desc = node->GetOpDesc();
  226. auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  227. if (sub_graph_names.empty()) {
  228. return cur_node_subgraph;
  229. }
  230. auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  231. for (const auto &name : sub_graph_names) {
  232. if (name.empty()) {
  233. GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
  234. continue;
  235. }
  236. auto sub_graph = root_graph->GetSubgraph(name);
  237. if (sub_graph == nullptr) {
  238. REPORT_INNER_ERROR("E19999", "Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
  239. GE_LOGE("[Get][Graph] can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
  240. continue;
  241. }
  242. cur_node_subgraph.emplace_back(sub_graph);
  243. }
  244. return cur_node_subgraph;
  245. }
  246. graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  247. // if infer again, update output of while into subgraph data node
  248. auto op_desc = node->GetOpDesc();
  249. for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
  250. for (const auto &node_sub : sub_graph->GetDirectNode()) {
  251. if (node_sub->GetType() != DATA) {
  252. continue;
  253. }
  254. auto name = sub_graph->GetName();
  255. int ref_i;
  256. auto data_opdesc = node_sub->GetOpDesc();
  257. if (data_opdesc == nullptr) {
  258. REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
  259. node->GetName().c_str());
  260. GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
  261. node->GetName().c_str());
  262. return GRAPH_FAILED;
  263. }
  264. if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
  265. REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
  266. name.c_str(), node->GetName().c_str());
  267. GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
  268. node->GetName().c_str());
  269. return GRAPH_FAILED;
  270. }
  271. if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
  272. continue;
  273. }
  274. auto input_desc = op_desc->MutableInputDesc(ref_i);
  275. if (input_desc == nullptr) {
  276. REPORT_INNER_ERROR("E19999",
  277. "The ref index(%d) on the data %s on the sub graph %s "
  278. "parent node %s are incompatible, inputs num %u",
  279. ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
  280. node->GetAllInDataAnchorsSize());
  281. GE_LOGE(
  282. "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s "
  283. "parent node %s are incompatible, inputs num %u",
  284. ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize());
  285. return GRAPH_FAILED;
  286. }
  287. GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
  288. node->GetName().c_str());
  289. // if need infer again, refresh subgraph input with output
  290. bool is_infer_again = false;
  291. AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, is_infer_again);
  292. if (is_infer_again) {
  293. input_desc = op_desc->MutableOutputDesc(ref_i);
  294. if (input_desc == nullptr) {
  295. REPORT_INNER_ERROR("E19999",
  296. "The ref index(%d) on the data %s on the subgraph %s "
  297. "parent node %s are incompatible, outputs num %u.",
  298. ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
  299. node->GetAllOutDataAnchorsSize());
  300. GELOGE(PARAM_INVALID,
  301. "[Call][MutableOutputDesc] The ref index(%d) on the data %s on the subgraph %s "
  302. "parent node %s are incompatible, outputs num %u.",
  303. ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
  304. node->GetAllOutDataAnchorsSize());
  305. }
  306. GELOGD("Update input desc of data %s on the sub graph %s of node %s,output idx: %d from [%s] to [%s]",
  307. node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), ref_i,
  308. data_opdesc->GetInputDescPtr(0)->GetShape().ToString().c_str(),
  309. input_desc->GetShape().ToString().c_str());
  310. }
  311. // auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
  312. bool input_changed = false;
  313. auto data_input_desc = data_opdesc->MutableInputDesc(0);
  314. auto ret = UpdateTensorDesc(input_desc, data_input_desc, input_changed);
  315. if (ret != GRAPH_SUCCESS) {
  316. REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s",
  317. node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
  318. GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", node_sub->GetName().c_str(),
  319. name.c_str(), node->GetName().c_str());
  320. return ret;
  321. }
  322. // ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
  323. bool output_changed = false;
  324. auto data_output_desc = data_opdesc->MutableOutputDesc(0);
  325. ret = UpdateTensorDesc(input_desc, data_output_desc, output_changed);
  326. if (ret != GRAPH_SUCCESS) {
  327. REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s",
  328. node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
  329. GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed",
  330. node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
  331. return ret;
  332. }
  333. if (input_changed || output_changed) {
  334. changed_nodes.insert(node_sub);
  335. }
  336. }
  337. }
  338. return GRAPH_SUCCESS;
  339. }
// Pulls inferred tensor descs from this node's subgraphs back to the parent:
// collects every subgraph DATA output desc (keyed by parent input index) and
// every ref-indexed NetOutput input desc (keyed by parent output index), then
// dispatches to the while-specific or branch-specific merge routine.
graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
  std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    auto name = sub_graph->GetName();
    NodePtr netoutput = nullptr;
    auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
    if (netoutput == nullptr) {
      REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(),
                         node->GetName().c_str());
      GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    auto netoutput_opdesc = netoutput->GetOpDesc();
    if (netoutput_opdesc == nullptr) {
      REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
                         name.c_str(), node->GetName().c_str());
      GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
      auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
      if (edge_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "Invalid NetOutput node on sub graph %s, parent node %s, "
                           "can not find input tensor %d",
                           name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
                name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(),
             edge_desc->GetShape().GetDimNum());
      int ref_i;
      if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        // if there is no ref index on the TensorDesc, it means the output data will be ignored outer.
        continue;
      }
      GELOGI("Parent node index of edge desc is %d", ref_i);
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
        // NOTE(review): this out-of-range case fails silently (no error report/log,
        // unlike every other failure path here) — confirm whether that is intended.
        return GRAPH_FAILED;
      }
      ref_out_tensors[ref_i].emplace_back(*edge_desc);
    }
  }
  if (node->GetType() == WHILE) {
    return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes);
  }
  return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes);
}
// Merges subgraph results back into a WHILE node's output descs.
// For each output index i, compares the single collected output tensor against
// every body DATA tensor of the same index: any mismatching dim becomes
// UNKNOWN_DIM (or the whole shape becomes unknown-rank if dim counts differ),
// and a shape range of {1,-1} is recorded per unknown dim. If a *fixed* input
// dim disagrees with the output, the loop has not converged and
// ATTR_NAME_NEED_INFER_AGAIN is set so the while body is inferred again.
graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node,
                                                    std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
                                                    std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                    std::set<NodePtr> &changed_nodes) {
  GELOGD("Enter update parent node shape for class while op process");
  if (ref_data_tensors.size() != ref_out_tensors.size()) {
    REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] is not same!",
                       node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(),
                       ref_out_tensors.size());
    GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!",
           node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size());
    return GRAPH_FAILED;
  }
  // while requires exactly one collected output tensor per output index
  for (size_t i = 0; i < ref_data_tensors.size(); i++) {
    if (ref_out_tensors[i].size() != 1) {
      REPORT_INNER_ERROR("E19999", "while op, every output should only find one output tensor in all graph!");
      GELOGE(GRAPH_FAILED, "[Check][Param] while op, every output should only find one output tensor in all graph!");
      return GRAPH_FAILED;
    }
  }
  bool need_infer_again = false;
  // check input and output
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    auto out_shape = ref_out_tensor.MutableShape();
    vector<std::pair<int64_t, int64_t>> data_shape_range;
    // ref_i's data and output tensor shape should be same
    for (auto &tensor : ref_data_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output",
                           node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto data_shape = tensor.MutableShape();
      // input is dynamic, here use dim_num
      if (data_shape.GetDims() != out_shape.GetDims()) {
        GELOGI("After infer, While %s %zu output shape [%s] is not match with input shape [%s].Need infer again.",
               node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str());
        if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
          ref_out_tensor.SetUnknownDimNumShape();
        } else {
          for (size_t j = 0; j < data_shape.GetDimNum(); ++j) {
            if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
              if (data_shape.GetDim(j) != UNKNOWN_DIM) {
                // if input data is fix shape, output is different, need_infer_again
                need_infer_again = true;
              }
              data_shape.SetDim(j, UNKNOWN_DIM);
            }
            // set shape rang of while, if dim is unknown ,set shape range as {1,-1}
            if (data_shape.GetDim(j) == UNKNOWN_DIM) {
              data_shape_range.emplace_back(std::make_pair(1, UNKNOWN_DIM));
            } else {
              data_shape_range.emplace_back(std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j)));
            }
          }
          // NOTE(review): data_shape_range is not cleared between data tensors, so
          // if more than one mismatching data tensor exists for index i the range
          // accumulates entries beyond the dim count — confirm this is intended.
          ref_out_tensor.SetShape(data_shape);
          ref_out_tensor.SetShapeRange(data_shape_range);
        }
      }
    }
    //(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
    bool output_changed = false;
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc, output_changed);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_infer_again);
  return GRAPH_SUCCESS;
}
// For multi-batch branch nodes: for each output index, picks the collected
// subgraph tensor with the largest element count (all candidates must share a
// dtype) and writes it into the parent node's output desc.
graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node,
                                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                     std::set<NodePtr> &changed_nodes) {
  // check sub_graph shape. Get max for update.
  for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    int64_t max_size = 0;
    size_t max_shape_index = 0;
    auto &ref_out_tensor = ref_out_tensors[i].at(0);
    for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
      auto &tensor = ref_out_tensors[i].at(j);
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output",
                           node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      int64_t size = 1;
      // element count with multiply-overflow guard
      // NOTE(review): a negative (unknown, -1) dim makes INT64_MAX / dim negative
      // and trips this guard as a false overflow — presumably shapes here are
      // static in the multi-batch case; confirm.
      for (auto dim : shape.GetDims()) {
        if (dim != 0 && INT64_MAX / dim < size) {
          REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(),
                             node->GetName().c_str());
          GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
          return PARAM_INVALID;
        }
        size *= dim;
      }
      if (size > max_size) {
        max_size = size;
        max_shape_index = j;
      }
    }
    //(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
    bool output_changed = false;
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensors[i].at(max_shape_index)), output_desc,
                           output_changed);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
// Merges subgraph NetOutput descs into a branch-class (if/case) parent node.
// Multi-batch nodes (ATTR_NAME_BATCH_NUM present) take the max-size path;
// otherwise, for each output index, dims that differ across branches become
// UNKNOWN_DIM, and differing ranks collapse to unknown-rank.
graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node,
                                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                     std::set<NodePtr> &changed_nodes) {
  GELOGD("Enter update parent node shape for class branch op process");
  if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
    return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes);
  }
  // check sub_graph shape.If not same ,do unknown shape process
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    // local copy of the first branch's tensor; merged in place below
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (auto &tensor : ref_out_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s",
                           node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        // rank mismatch between branches: whole shape becomes unknown-rank
        GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
               shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
        break;
      }
      for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
        if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
          continue;
        }
        GELOGD("node is %s, i : %zu, j: %zu ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(),
               i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
      }
    }
    //(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
    bool output_changed = false;
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc, output_changed);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
// just for merge, to be deleted
// Copies each output desc (shape, origin shape, dtype, origin dtype, shape
// range, real dim count) to the connected peer input descs. Skips unknown-shape
// graphs, peers already processed in topological order (peer id > this id would
// be re-visited anyway — here peers with *smaller* ids are skipped), and
// constant peers. Peers whose shape dims changed are added to changed_nodes.
graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  auto op_desc = node->GetOpDesc();
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  if (is_unknown_graph) {
    return GRAPH_SUCCESS;
  }
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
      auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc();
      if (peer_anchor_opdesc == nullptr) {
        continue;
      }
      // skip peers scheduled before this node, and constants
      if (op_desc->GetId() < peer_anchor_opdesc->GetId() || peer_anchor_opdesc->GetType() == CONSTANT ||
          peer_anchor_opdesc->GetType() == CONSTANTOP) {
        continue;
      }
      auto peer_input_desc = peer_anchor_opdesc->MutableInputDesc(peer_anchor->GetIdx());
      if (peer_input_desc == nullptr) {
        continue;
      }
      // record change before overwriting; only shape dims are compared
      bool changed = false;
      if (peer_input_desc->GetShape().GetDims() != output_tensor->GetShape().GetDims()) {
        changed = true;
      }
      peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
      peer_input_desc->SetShape(output_tensor->GetShape());
      peer_input_desc->SetDataType(output_tensor->GetDataType());
      peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
      std::vector<std::pair<int64_t, int64_t>> shape_range;
      (void)output_tensor->GetShapeRange(shape_range);
      peer_input_desc->SetShapeRange(shape_range);
      ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
                                     static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
      if (changed) {
        changed_nodes.insert(peer_anchor->GetOwnerNode());
      }
    }
  }
  return GRAPH_SUCCESS;
}
  608. void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) {
  609. if (!IsLogEnable(GE, DLOG_DEBUG)) {
  610. return;
  611. }
  612. if (node == nullptr) {
  613. REPORT_INNER_ERROR("E19999", "param node is nullprt, check invalid");
  614. GELOGE(GRAPH_FAILED, "[Check][Param] node is null");
  615. return;
  616. }
  617. ge::OpDescPtr op_desc = node->GetOpDesc();
  618. GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid");
  619. GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return );
  620. std::stringstream ss;
  621. ss << "{";
  622. int32_t in_idx = 0;
  623. int32_t out_idx = 0;
  624. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  625. if (input_desc == nullptr) {
  626. in_idx++;
  627. continue;
  628. }
  629. if (in_idx > 0) {
  630. ss << " ";
  631. }
  632. ss << "input_" << in_idx << " "
  633. << "tensor: [";
  634. ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
  635. ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
  636. ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
  637. ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
  638. ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
  639. ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
  640. string range_str;
  641. SerialShapeRange(input_desc, range_str);
  642. ss << "(shape_range:" << range_str << ")]";
  643. in_idx++;
  644. }
  645. for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
  646. if (output_desc == nullptr) {
  647. out_idx++;
  648. continue;
  649. }
  650. ss << " ";
  651. ss << "output_" << out_idx << " "
  652. << "tensor: [";
  653. ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),";
  654. ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),";
  655. ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),";
  656. ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),";
  657. ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),";
  658. ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),";
  659. string range_str;
  660. SerialShapeRange(output_desc, range_str);
  661. ss << "(shape_range:" << range_str << ")]";
  662. out_idx++;
  663. }
  664. ss << "}";
  665. GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str());
  666. }
  667. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示