
infer_base_pass.cc

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "infer_base_pass.h"
#include "common/ge/ge_util.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
namespace ge {
namespace {
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  desc_str += formats::RangeToString(shape_range);
  shape_range.clear();
  (void)desc->GetOriginShapeRange(shape_range);
  desc_str += ",";
  desc_str += formats::RangeToString(shape_range);
  shape_range.clear();
}
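// Scans the subgraph's direct nodes (in reverse order), recording its NETOUTPUT node and collecting the
// output desc of each DATA node into ref_data_tensors, indexed by the DATA node's parent-node ref index.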
graphStatus FindSubgraphDataAndNetoutput(const ComputeGraphPtr &sub_graph, NodePtr &netoutput, const ConstNodePtr &node,
                                         std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    if (sub_node->GetType() == NETOUTPUT) {
      netoutput = sub_node;
    }
    if (sub_node->GetType() == DATA) {
      if (sub_node->GetOpDesc() == nullptr) {
        return GRAPH_FAILED;
      }
      int ref_i;
      if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
        REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!",
                           sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
        GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!",
               sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
    }
  }
  return GRAPH_SUCCESS;
}
}  // namespace
Status InferBasePass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  bool need_infer = NeedInfer(node);
  if (!need_infer) {
    GELOGD("Node %s does not need to infer.", node->GetName().c_str());
    return SUCCESS;
  }
  std::set<NodePtr> changed_nodes;
  auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(ret, "Infer and update for node %s failed! ret: %u", node->GetName().c_str(), ret);
    return GRAPH_FAILED;
  }
  AddChangedNodesImmediateRepass(changed_nodes);
  return SUCCESS;
}
bool InferBasePass::NeedInfer(const NodePtr &node) { return true; }
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) {
  for (const auto &node_ele : changed_nodes) {
    AddImmediateRePassNode(node_ele);
  }
}
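// Template flow for a single node: before its subgraphs are processed, push the parent's input descs down to
// the subgraph Data nodes; run the pass-specific Infer(); after the subgraphs have been processed, pull the
// subgraph results back up to the parent node; finally propagate the output descs to peer input anchors.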
graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {
  graphStatus ret;
  bool contain_subgraph = ContainsSubgraph(node);
  if (contain_subgraph && before_subgraph) {
    ret = UpdateTensorDescToSubgraphData(node, changed_nodes);
    if (ret != GRAPH_SUCCESS) {
      GELOGE(ret, "Update subgraph data tensor desc for node %s failed! ret: %u", node->GetName().c_str(), ret);
      return ret;
    }
  }
  ret = Infer(node);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(ret, "Infer failed for node %s, ret: %u", node->GetName().c_str(), ret);
    return ret;
  }
  if (contain_subgraph && !before_subgraph) {
    ret = UpdateTensorDescToParentNode(node, changed_nodes);
    if (ret != GRAPH_SUCCESS) {
      GELOGE(ret, "Update parent tensor desc for node %s failed! ret: %u", node->GetName().c_str(), ret);
      return ret;
    }
  }
  ret = UpdateTensorDescToPeerInputs(node, changed_nodes);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(ret, "Node %s updates tensor desc to peer input nodes failed! ret: %u", node->GetName().c_str(), ret);
  }
  return ret;
}
bool InferBasePass::ContainsSubgraph(const NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return false;
  }
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  if (root_graph == nullptr) {
    return false;
  }
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph != nullptr) {
      return true;
    }
  }
  return false;
}
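// Propagates each output desc of the node to every connected peer input anchor through the
// UpdatePeerInputDesc hook; consumers whose input desc actually changed are collected for re-pass.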
graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  PrintInOutTensorShape(node, "after_infer");
  auto op_desc = node->GetOpDesc();
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
      auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc();
      if (peer_anchor_opdesc == nullptr) {
        continue;
      }
      auto peer_input_desc = peer_anchor_opdesc->MutableInputDesc(peer_anchor->GetIdx());
      if (peer_input_desc == nullptr) {
        continue;
      }
      bool changed = false;
      auto ret = UpdatePeerInputDesc(output_tensor, peer_input_desc, changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Update peer input desc failed, node %s.", node->GetName().c_str());
        GELOGE(ret, "Update peer input desc failed, node %s.", node->GetName().c_str());
        return ret;
      }
      if (changed) {
        changed_nodes.insert(peer_anchor->GetOwnerNode());
      }
    }
  }
  return GRAPH_SUCCESS;
}
graphStatus InferBasePass::UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  // Default implementation: leave the peer input desc unchanged; concrete passes override this hook.
  changed = false;
  return GRAPH_SUCCESS;
}
std::vector<ComputeGraphPtr> InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) {
  std::vector<ComputeGraphPtr> cur_node_subgraph;
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return cur_node_subgraph;
  }
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      REPORT_INNER_ERROR("E19999", "Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      GE_LOGE("[Get][Graph] can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      continue;
    }
    cur_node_subgraph.emplace_back(sub_graph);
  }
  return cur_node_subgraph;
}
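// Pushes the parent node's input tensor descs down to the corresponding Data nodes of each subgraph, so that
// inference inside the subgraph starts from the parent's current inputs. For While nodes whose descs changed,
// the subgraph inputs are refreshed from the While outputs instead.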
graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  // if inferring again, update the While output into the subgraph data node
  auto op_desc = node->GetOpDesc();
  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      if (node_sub->GetType() != DATA) {
        continue;
      }
      auto name = sub_graph->GetName();
      int ref_i;
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                           node->GetName().c_str());
        GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
                           name.c_str(), node->GetName().c_str());
        GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      // In multi-batch scenarios, the subgraph data shapes differ; no need to refresh.
      if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
        continue;
      }
      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "The ref index(%d) on the data %s on the sub graph %s "
                           "parent node %s is incompatible, inputs num %u",
                           ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
                           node->GetAllInDataAnchorsSize());
        GE_LOGE(
            "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s "
            "parent node %s is incompatible, inputs num %u",
            ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
             node->GetName().c_str());
      auto data_input_desc = data_opdesc->MutableInputDesc(0);
      if (!SameTensorDesc(input_desc, data_input_desc)) {
        changed_nodes.insert(node_sub);
        // if another inference round is needed, refresh the While subgraph input with the While output
        if (node->GetType() == WHILE) {
          input_desc = op_desc->MutableOutputDesc(ref_i);
        }
      }
      auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", node_sub->GetName().c_str(),
                name.c_str(), node->GetName().c_str());
        return ret;
      }
      ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
    }
  }
  return GRAPH_SUCCESS;
}
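// Pulls the subgraph inference results back up to the parent node: collects the Data descs and the NetOutput
// input descs of every subgraph, then merges them into the parent's output descs. While nodes and
// branch-class nodes are merged by different strategies below.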
graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
  std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    auto name = sub_graph->GetName();
    NodePtr netoutput = nullptr;
    auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
    if (netoutput == nullptr) {
      REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(),
                         node->GetName().c_str());
      GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    auto netoutput_opdesc = netoutput->GetOpDesc();
    if (netoutput_opdesc == nullptr) {
      REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
                         name.c_str(), node->GetName().c_str());
      GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
      auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
      if (edge_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "Invalid NetOutput node on sub graph %s, parent node %s, "
                           "can not find input tensor %d",
                           name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
                name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(),
             edge_desc->GetShape().GetDimNum());
      int ref_i;
      if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        // if there is no ref index on the TensorDesc, the output data is ignored by the outer graph.
        continue;
      }
      GELOGI("Parent node index of edge desc is %d", ref_i);
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
        return GRAPH_FAILED;
      }
      ref_out_tensors[ref_i].emplace_back(*edge_desc);
    }
  }
  if (node->GetType() == WHILE) {
    return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes);
  }
  return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes);
}
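// For While nodes: each output must come from exactly one NetOutput tensor. Where the body output shape
// disagrees with the corresponding input shape, the differing dims are set to UNKNOWN_DIM (or the whole
// shape to unknown rank when the dim counts differ), which forces another inference round.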
graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node,
                                                    std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
                                                    std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                    std::set<NodePtr> &changed_nodes) {
  GELOGD("Enter update parent node shape for class while op process");
  if (ref_data_tensors.size() != ref_out_tensors.size()) {
    REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] are not the same!",
                       node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(),
                       ref_out_tensors.size());
    GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] are not the same!",
           node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size());
    return GRAPH_FAILED;
  }
  // check input and output
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].size() != 1) {
      REPORT_INNER_ERROR("E19999", "while op, each output should correspond to exactly one output tensor across all subgraphs!");
      GELOGE(GRAPH_FAILED, "[Check][Param] while op, each output should correspond to exactly one output tensor across all subgraphs!");
      return GRAPH_FAILED;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    for (auto &tensor : ref_data_tensors[i]) {
      // the data tensor and the output tensor at ref_i should have the same shape
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output",
                           node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto data_shape = tensor.MutableShape();
      auto out_shape = ref_out_tensor.MutableShape();
      if (data_shape.GetDims() != out_shape.GetDims()) {
        GELOGI("After infer, While %s %zu output shape [%s] does not match its input shape [%s]. Need to infer again.",
               node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str());
        if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
          ref_out_tensor.SetUnknownDimNumShape();
        } else {
          size_t data_dim_num = data_shape.GetDimNum();
          std::vector<std::pair<int64_t, int64_t>> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)};
          for (size_t j = 0; j < data_dim_num; ++j) {
            if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
              data_shape.SetDim(j, UNKNOWN_DIM);
            }
            if (data_shape.GetDim(j) != UNKNOWN_DIM) {
              data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j));
            }
          }
          ref_out_tensor.SetShape(data_shape);
          ref_out_tensor.SetShapeRange(data_shape_range);
        }
      }
    }
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
    // the node has only changed if the merged desc differs from the previous output desc
    bool output_changed = !SameTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
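// For multi-batch branch nodes: all ref outputs of an index must share a dtype; the desc with the largest
// element count among the subgraphs becomes the parent's output desc.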
graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node,
                                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                     std::set<NodePtr> &changed_nodes) {
  // Check sub_graph shapes and take the max for the update.
  for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    int64_t max_size = 0;
    size_t max_shape_index = 0;
    auto &ref_out_tensor = ref_out_tensors[i].at(0);
    for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
      auto &tensor = ref_out_tensors[i].at(j);
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output",
                           node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      int64_t size = 1;
      for (auto dim : shape.GetDims()) {
        if (dim != 0 && INT64_MAX / dim < size) {
          REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(),
                             node->GetName().c_str());
          GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
          return PARAM_INVALID;
        }
        size *= dim;
      }
      if (size > max_size) {
        max_size = size;
        max_shape_index = j;
      }
    }
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
    // the node has only changed if the chosen desc differs from the previous output desc
    bool output_changed =
        !SameTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensors[i].at(max_shape_index)), output_desc);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
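// For branch-class nodes (non-While nodes with subgraphs): merges the ref outputs of all subgraphs into one
// desc per parent output, setting disagreeing dims to UNKNOWN_DIM, or the whole shape to unknown rank when
// the dim counts differ.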
graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node,
                                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                     std::set<NodePtr> &changed_nodes) {
  GELOGD("Enter update parent node shape for class branch op process");
  if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
    return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes);
  }
  // check sub_graph shape. If not the same, do unknown-shape processing.
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (auto &tensor : ref_out_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s",
                           node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
               shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
        break;
      }
      for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
        if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
          continue;
        }
        GELOGD("node is %s, i : %zu, j: %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(),
               i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
      }
    }
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
    // the node has only changed if the merged desc differs from the previous output desc
    bool output_changed = !SameTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) {
  if (!IsLogEnable(GE, DLOG_DEBUG)) {
    return;
  }
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "param node is nullptr, check invalid");
    GELOGE(GRAPH_FAILED, "[Check][Param] node is null");
    return;
  }
  ge::OpDescPtr op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid");
                  GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return );
  std::stringstream ss;
  ss << "{";
  int32_t in_idx = 0;
  int32_t out_idx = 0;
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (input_desc == nullptr) {
      in_idx++;
      continue;
    }
    if (in_idx > 0) {
      ss << " ";
    }
    ss << "input_" << in_idx << " "
       << "tensor: [";
    ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
    string range_str;
    SerialShapeRange(input_desc, range_str);
    ss << "(shape_range:" << range_str << "),";
    std::vector<std::pair<int64_t, int64_t>> value_range;
    (void)input_desc->GetValueRange(value_range);
    string value_range_str = formats::RangeToString(value_range);
    ss << "(value_range:" << value_range_str << ")]";
    in_idx++;
  }
  for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
    if (output_desc == nullptr) {
      out_idx++;
      continue;
    }
    ss << " ";
    ss << "output_" << out_idx << " "
       << "tensor: [";
    ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),";
    string range_str;
    SerialShapeRange(output_desc, range_str);
    ss << "(shape_range:" << range_str << "),";
    std::vector<std::pair<int64_t, int64_t>> value_range;
    (void)output_desc->GetValueRange(value_range);
    string value_range_str = formats::RangeToString(value_range);
    ss << "(value_range:" << value_range_str << ")]";
    out_idx++;
  }
  ss << "}";
  GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str());
}
}  // namespace ge
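
For reference, a concrete pass derives from InferBasePass and fills in the hooks this file leaves open: NeedInfer() filters nodes, Infer() performs the pass-specific inference, and UpdatePeerInputDesc() decides how an output desc is copied into a consumer's input desc. The following is a minimal hypothetical sketch, not the real GE subclass; it assumes the hooks are declared virtual in infer_base_pass.h, that SameTensorDesc is visible there, and that RunNodeInferFunc stands in for whatever inference entry point the concrete pass actually calls.

class MyInferShapePass : public InferBasePass {
 protected:
  graphStatus Infer(NodePtr &node) override {
    // Pass-specific inference; RunNodeInferFunc is a hypothetical stand-in for the real entry point.
    return RunNodeInferFunc(node);
  }
  graphStatus UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override {
    // Copy shape and dtype downstream; reporting `changed` lets the base class
    // collect the consumer node for an immediate re-pass.
    changed = !SameTensorDesc(src, dst);
    if (changed) {
      dst->SetShape(src->GetShape());
      dst->SetDataType(src->GetDataType());
    }
    return GRAPH_SUCCESS;
  }
};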

The Graph Engine (GE) module is a submodule of MindSpore. It is implemented in C++ and sits between the front-end module (ME) and the underlying hardware, acting as a bridge between the two. GE takes the graph delivered by ME as input, performs a series of deep graph-optimization operations on it, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.