
infer_base_pass.cc 28 kB

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "infer_base_pass.h"

#include "common/ge/ge_util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"

namespace ge {
namespace {
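// Debug-log helpers: Serial renders a dim vector as "[d0 d1 ... ]";
// SerialShapeRange appends a tensor's shape range and origin shape range.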
string Serial(const vector<int64_t> &dims) {
  string serial_string;
  serial_string += "[";
  for (int64_t dim : dims) {
    serial_string += std::to_string(dim) + " ";
  }
  serial_string += "]";
  return serial_string;
}

void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  desc_str += "[";
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += "{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
  desc_str += "]";
  shape_range.clear();
  (void)desc->GetOriginShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += ",{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
}
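// Scans the direct nodes of sub_graph in reverse, records the NETOUTPUT node and, for
// every DATA node, collects its output desc into ref_data_tensors, indexed by the
// parent-node input it refers to (ATTR_NAME_PARENT_NODE_INDEX).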
graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput,
                                         const ConstNodePtr &node,
                                         std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    if (sub_node->GetType() == NETOUTPUT) {
      netoutput = sub_node;
    }
    if (sub_node->GetType() == DATA) {
      if (sub_node->GetOpDesc() == nullptr) {
        return GRAPH_FAILED;
      }
      int ref_i;
      if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
        REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!",
                           sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
        GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!",
               sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
    }
  }
  return GRAPH_SUCCESS;
}
}  // namespace
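// Entry point of the pass: skips nodes that do not need inference, runs InferAndUpdate
// (propagating tensor descs into or out of subgraphs as appropriate), then triggers the
// loop-node repass hook.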
Status InferBasePass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  bool need_infer = NeedInfer(node);
  if (!need_infer) {
    GELOGD("Node %s does not need to infer.", node->GetName().c_str());
    return SUCCESS;
  }
  std::set<NodePtr> changed_nodes;
  auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes);
  if (ret != GRAPH_SUCCESS) {
    (void)AnalyzeFailedInfo(node);
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  // AddChangedNodesImmediateRepass(changed_nodes);
  auto status = DoRepassForLoopNode(node);
  if (status != SUCCESS) {
    GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "repass failed. node: %s", node->GetName().c_str());
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  return SUCCESS;
}

bool InferBasePass::NeedInfer(const NodePtr &node) { return true; }
void InferBasePass::AnalyzeFailedInfo(const NodePtr &node) { /* Analyze and select failed info */ }
Status InferBasePass::DoRepassForLoopNode(NodePtr &node) { return SUCCESS; }

void InferBasePass::AddChangedNodesImmediateRepass(std::set<NodePtr> &changed_nodes) {
  for (const auto &node_ele : changed_nodes) {
    AddImmediateRePassNode(node_ele);
  }
}
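// Runs Infer on the node. When the node owns subgraphs, parent input descs are pushed
// down to the subgraph Data nodes before inference (before_subgraph == true), or the
// subgraph results are lifted back to the parent node after inference; otherwise the
// inferred output descs are propagated to peer inputs.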
graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {
  bool contain_subgraph = ContainsSubgraph(node);
  if (contain_subgraph && before_subgraph) {
    auto ret = UpdateTensorDescToSubgraphData(node, changed_nodes);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  auto ret = Infer(node);
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }
  if (contain_subgraph && !before_subgraph) {
    return UpdateTensorDescToParentNode(node, changed_nodes);
  }
  return UpdateTensorDescToPeerInputs(node, changed_nodes);
}

bool InferBasePass::ContainsSubgraph(const NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return false;
  }
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph != nullptr) {
      return true;
    }
  }
  return false;
}
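// Copies the parent node's input desc (or its output desc when the node is marked
// ATTR_NAME_NEED_INFER_AGAIN, e.g. a While being re-inferred) onto the input and output
// desc of each subgraph Data node that carries a parent-node index.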
graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  // if infer again, update output of while into subgraph data node
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      REPORT_INNER_ERROR("E19999", "Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      GE_LOGE("[Get][Graph] Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      if (node_sub->GetType() != DATA) {
        continue;
      }
      int ref_i;
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                           node->GetName().c_str());
        GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
                           name.c_str(), node->GetName().c_str());
        GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
                name.c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
        continue;
      }
      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "The ref index(%d) on the data %s on the sub graph %s "
                           "parent node %s is incompatible, inputs num %u",
                           ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
                           node->GetAllInDataAnchorsSize());
        GE_LOGE(
            "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s "
            "parent node %s is incompatible, inputs num %u",
            ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
            node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
             node->GetName().c_str());
      // if need infer again, refresh subgraph input with output
      bool is_infer_again = false;
      AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, is_infer_again);
      if (is_infer_again) {
        input_desc = op_desc->MutableOutputDesc(ref_i);
        if (input_desc == nullptr) {
          REPORT_INNER_ERROR("E19999",
                             "The ref index(%d) on the data %s on the subgraph %s "
                             "parent node %s is incompatible, outputs num %u.",
                             ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
                             node->GetAllOutDataAnchorsSize());
          GELOGE(PARAM_INVALID,
                 "[Call][MutableOutputDesc] The ref index(%d) on the data %s on the subgraph %s "
                 "parent node %s is incompatible, outputs num %u.",
                 ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
                 node->GetAllOutDataAnchorsSize());
          return PARAM_INVALID;
        }
        GELOGD("Update input desc of data %s on the sub graph %s of node %s, output idx: %d from [%s] to [%s]",
               node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), ref_i,
               data_opdesc->GetInputDescPtr(0)->GetShape().ToString().c_str(),
               input_desc->GetShape().ToString().c_str());
      }
      // auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
      bool input_changed = false;
      auto data_input_desc = data_opdesc->MutableInputDesc(0);
      auto ret = UpdateTensorDesc(input_desc, data_input_desc, input_changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
      // ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
      bool output_changed = false;
      auto data_output_desc = data_opdesc->MutableOutputDesc(0);
      ret = UpdateTensorDesc(input_desc, data_output_desc, output_changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
      if (input_changed || output_changed) {
        changed_nodes.insert(node_sub);
      }
    }
  }
  return GRAPH_SUCCESS;
}
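// Collects the tensor descs produced inside each subgraph (Data nodes and NetOutput
// inputs tagged with a parent-node index) and merges them back onto the parent node's
// output descs; While nodes and branch nodes (If/Case) are merged differently.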
graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }
  std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
  std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      REPORT_INNER_ERROR("E19999", "Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      GE_LOGE("[Get][Subgraph] Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    NodePtr netoutput = nullptr;
    auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
    if (netoutput == nullptr) {
      REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(),
                         node->GetName().c_str());
      GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    auto netoutput_opdesc = netoutput->GetOpDesc();
    if (netoutput_opdesc == nullptr) {
      REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
                         name.c_str(), node->GetName().c_str());
      GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
      auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
      if (edge_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "Invalid NetOutput node on sub graph %s, parent node %s, "
                           "can not find input tensor %d",
                           name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
                name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(),
             edge_desc->GetShape().GetDimNum());
      int ref_i;
      if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        // if there is no ref index on the TensorDesc, the output data is ignored by the outer graph.
        continue;
      }
      GELOGI("Parent node index of edge desc is %d", ref_i);
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
        return GRAPH_FAILED;
      }
      ref_out_tensors[ref_i].emplace_back(*edge_desc);
    }
  }
  if (node->GetType() == WHILE) {
    return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes);
  }
  return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes);
}
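// For a While node: compares each body output with the corresponding input. Any dim that
// differs becomes UNKNOWN_DIM (with shape range {1, -1}); if a fixed input dim changed,
// ATTR_NAME_NEED_INFER_AGAIN is set so the loop body is inferred once more.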
graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node,
                                                    std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
                                                    std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                    std::set<NodePtr> &changed_nodes) {
  GELOGD("Enter update parent node shape for class while op process");
  if (ref_data_tensors.size() != ref_out_tensors.size()) {
    REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] is not same!",
                       node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(),
                       ref_out_tensors.size());
    GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!",
           node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size());
    return GRAPH_FAILED;
  }
  for (size_t i = 0; i < ref_data_tensors.size(); i++) {
    if (ref_out_tensors[i].size() != 1) {
      REPORT_INNER_ERROR("E19999", "while op, every output should find exactly one output tensor across all graphs!");
      GELOGE(GRAPH_FAILED,
             "[Check][Param] while op, every output should find exactly one output tensor across all graphs!");
      return GRAPH_FAILED;
    }
  }
  bool need_infer_again = false;
  // check input and output
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    auto out_shape = ref_out_tensor.MutableShape();
    vector<std::pair<int64_t, int64_t>> data_shape_range;
    // ref_i's data and output tensor shape should be same
    for (auto &tensor : ref_data_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output",
                           node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto data_shape = tensor.MutableShape();
      // input is dynamic, here use dim_num
      if (data_shape.GetDims() != out_shape.GetDims()) {
        GELOGI("After infer, While %s %zu output shape [%s] does not match input shape [%s]. Need infer again.",
               node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str());
        if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
          ref_out_tensor.SetUnknownDimNumShape();
        } else {
          for (size_t j = 0; j < data_shape.GetDimNum(); ++j) {
            if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
              if (data_shape.GetDim(j) != UNKNOWN_DIM) {
                // if input data is fixed shape and output differs, need_infer_again
                need_infer_again = true;
              }
              data_shape.SetDim(j, UNKNOWN_DIM);
            }
            // set shape range of while; if a dim is unknown, set its shape range to {1, -1}
            if (data_shape.GetDim(j) == UNKNOWN_DIM) {
              data_shape_range.emplace_back(std::make_pair(1, UNKNOWN_DIM));
            } else {
              data_shape_range.emplace_back(std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j)));
            }
          }
          ref_out_tensor.SetShape(data_shape);
          ref_out_tensor.SetShapeRange(data_shape_range);
        }
      }
    }
    // (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
    bool output_changed = false;
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc, output_changed);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_infer_again);
  return GRAPH_SUCCESS;
}
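// Multi-batch case: for each output index, picks the ref tensor with the largest element
// count across all subgraph branches and writes it back to the parent output desc.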
graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node,
                                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                     std::set<NodePtr> &changed_nodes) {
  // check sub_graph shape. Get max for update.
  for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    int64_t max_size = 0;
    size_t max_shape_index = 0;
    auto &ref_out_tensor = ref_out_tensors[i].at(0);
    for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
      auto &tensor = ref_out_tensors[i].at(j);
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output",
                           node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      int64_t size = 1;
      for (auto dim : shape.GetDims()) {
        if (dim != 0 && INT64_MAX / dim < size) {
          REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(),
                             node->GetName().c_str());
          GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
          return PARAM_INVALID;
        }
        size *= dim;
      }
      if (size > max_size) {
        max_size = size;
        max_shape_index = j;
      }
    }
    // (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
    bool output_changed = false;
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensors[i].at(max_shape_index)), output_desc,
                           output_changed);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
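// Branch case (If/Case): merges the outputs of all subgraphs per output index. A rank
// mismatch yields an unknown-rank shape; a dim mismatch yields UNKNOWN_DIM for that dim.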
graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node,
                                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                     std::set<NodePtr> &changed_nodes) {
  GELOGD("Enter update parent node shape for class branch op process");
  if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
    return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes);
  }
  // check sub_graph shape. If not same, do unknown shape process
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (auto &tensor : ref_out_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s",
                           node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        GELOGD("node is %s, i : %zu, shape size: %ld, ref_out_tensor_shape size: %ld", node->GetName().c_str(), i,
               shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
        break;
      }
      for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
        if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
          continue;
        }
        GELOGD("node is %s, i : %zu, j: %zu, shape size: %ld, ref_out_tensor_shape size: %ld",
               node->GetName().c_str(), i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
      }
    }
    // (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
    bool output_changed = false;
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc, output_changed);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
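// Pushes the node's output descs to the input descs of downstream peers (skipping
// constants and already-visited nodes); peers whose desc actually changed are recorded
// for repass.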
graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  auto op_desc = node->GetOpDesc();
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  if (is_unknown_graph) {
    return GRAPH_SUCCESS;
  }
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
      auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc();
      if (peer_anchor_opdesc == nullptr) {
        continue;
      }
      if (op_desc->GetId() < peer_anchor_opdesc->GetId() || peer_anchor_opdesc->GetType() == CONSTANT ||
          peer_anchor_opdesc->GetType() == CONSTANTOP) {
        continue;
      }
      auto peer_input_desc = peer_anchor_opdesc->MutableInputDesc(peer_anchor->GetIdx());
      if (peer_input_desc == nullptr) {
        continue;
      }
      bool changed = false;
      auto ret = UpdateDescAttrForPeerInput(output_tensor, peer_input_desc, changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update peer tensor desc attr");
        GE_LOGE("[Update][PeerInputDesc] Failed to update peer tensor desc attr");
        return ret;
      }
      if (changed) {
        changed_nodes.insert(peer_anchor->GetOwnerNode());
      }
    }
  }
  PrintInOutTensorShape(node, "after_infer");
  return GRAPH_SUCCESS;
}

graphStatus InferBasePass::UpdateDescAttrForPeerInput(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) {
  changed = false;
  return GRAPH_SUCCESS;
}
void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) {
  if (!IsLogEnable(GE, DLOG_DEBUG)) {
    return;
  }
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "param node is nullptr, check invalid");
    GELOGE(GRAPH_FAILED, "[Check][Param] node is null");
    return;
  }
  ge::OpDescPtr op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid");
                  GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return );
  std::stringstream ss;
  ss << "{";
  int32_t in_idx = 0;
  int32_t out_idx = 0;
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (input_desc == nullptr) {
      in_idx++;
      continue;
    }
    if (in_idx > 0) {
      ss << " ";
    }
    ss << "input_" << in_idx << " "
       << "tensor: [";
    ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
    string range_str;
    SerialShapeRange(input_desc, range_str);
    ss << "(shape_range:" << range_str << ")]";
    in_idx++;
  }
  for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
    if (output_desc == nullptr) {
      out_idx++;
      continue;
    }
    ss << " ";
    ss << "output_" << out_idx << " "
       << "tensor: [";
    ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),";
    string range_str;
    SerialShapeRange(output_desc, range_str);
    ss << "(shape_range:" << range_str << ")]";
    out_idx++;
  }
  ss << "}";
  GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str());
}
}  // namespace ge
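The pass above is a template: NeedInfer, Infer, UpdateTensorDesc and UpdateDescAttrForPeerInput are the hooks a concrete pass overrides. As a minimal sketch, assuming these hooks are declared virtual in infer_base_pass.h with the signatures used above, a shape-only derived pass might look like the following; the class name InferShapeOnlyPass and the delegation to ShapeRefiner are illustrative, not part of this file:

#include "infer_base_pass.h"
#include "graph/shape_refiner.h"

namespace ge {
// Hypothetical derived pass: delegates inference to ShapeRefiner and copies only
// shape and dtype into the destination desc, reporting whether anything changed.
class InferShapeOnlyPass : public InferBasePass {
 protected:
  graphStatus Infer(NodePtr &node) override {
    // Reuse the stock shape/type inference (assumed hook signature).
    return ShapeRefiner::InferShapeAndType(node);
  }

  graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override {
    if (src == nullptr || dst == nullptr) {
      return GRAPH_FAILED;
    }
    // Mark the peer as changed only when shape or dtype actually differ.
    changed = (dst->GetShape().GetDims() != src->GetShape().GetDims()) ||
              (dst->GetDataType() != src->GetDataType());
    dst->SetShape(src->GetShape());
    dst->SetDataType(src->GetDataType());
    return GRAPH_SUCCESS;
  }
};
}  // namespace ge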

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware and acts as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and outputs a graph that runs efficiently on the underlying hardware. GE is specifically tuned to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture diagram is shown below.
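For orientation, here is a minimal sketch of how a client typically drives GE through its public API (ge/ge_api.h): initialize, create a session, add a graph, run it, finalize. The option maps and graph construction are elided, and exact option keys vary by deployment, so treat this as an assumption-laden outline rather than canonical usage.

#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"  // GE client API: GEInitialize / Session / GEFinalize

int main() {
  std::map<std::string, std::string> options;  // deployment-specific options (elided)
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return -1;
  }
  {
    ge::Session session(options);
    ge::Graph graph("demo");  // graph construction elided; normally built from ME's output
    uint32_t graph_id = 0;
    (void)session.AddGraph(graph_id, graph);
    std::vector<ge::Tensor> inputs;
    std::vector<ge::Tensor> outputs;
    (void)session.RunGraph(graph_id, inputs, outputs);  // GE optimizes, then executes the graph
  }
  (void)ge::GEFinalize();
  return 0;
}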