You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

multi_batch_clone_pass.cc 49 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/multi_batch_clone_pass.h"
  17. #include "common/formats/utils/formats_trans_utils.h"
  18. #include "common/ge/ge_util.h"
  19. #include "graph/common/local_context.h"
  20. #include "graph/preprocess/multi_batch_options.h"
  21. #include "graph/utils/node_utils.h"
  22. #include "graph/utils/op_desc_utils.h"
  23. #include "graph/utils/tensor_utils.h"
  24. #include "graph/utils/type_utils.h"
  25. #include "register/op_registry.h"
  26. #include "graph/common/omg_util.h"
  27. namespace ge {
namespace {
// Anchor index of the (single) data input/output on a Data node.
constexpr uint8_t kDataInIndex = 0;
constexpr uint8_t kDataOutIndex = 0;
// Case node inputs: anchor 0 is the branch index, real args start at 1.
constexpr uint8_t kCaseArgIndex = 1;
// GetNext sink outputs are split in half (first half data, second half
// shapes), hence the division by 2 — see InitParamsOfGetNext.
const int kDivisionConst = 2;
// At most one GetNext node is expected per graph.
const size_t kNumOfGetnextNode = 1;
// Fixed names for the helper nodes this pass inserts into the root graph.
const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case";
const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data";
const std::string kMultiBatchGetDynamicDimsNode = "ascend_mbatch_get_dynamic_dims_node";
const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
// Postfix appended to cloned node names per batch gear.
const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_";
// Original op type that is treated as GetNext.
const char *const kGetNextName = "IteratorV2";
}  // namespace
  42. inline bool IsGetNextType(const NodePtr &node) {
  43. std::string original_type;
  44. GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS,
  45. GELOGW("Get original type failed."); return false);
  46. return (original_type == kGetNextName);
  47. }
///
/// @ingroup ge
/// @brief Pass entry: validate the multi-batch options, then clone the
///        original graph content into a branch graph and rebuild `graph`
///        as a root graph driven by a Case node.
/// @param [in] graph: original compute graph; modified in place.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
  GE_IF_BOOL_EXEC(graph == nullptr, GELOGE(FAILED, "Original graph is nullptr"); return FAILED);
  // Only the root graph is processed; subgraphs are skipped.
  if (graph->GetParentGraph() != nullptr) {
    GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str());
    return SUCCESS;
  }
  if (!GetLocalOmgContext().need_multi_batch) {
    GELOGI("No need to process_multi for no_train graph.");
    return SUCCESS;
  }
  std::vector<NodePtr> data_nodes;
  std::vector<NodePtr> getnext_nosink_nodes;
  std::vector<NodePtr> getnext_sink_nodes;
  // Validate the option sequence and normalize input-shape names up front.
  if (multibatch::CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] CheckSequenceOfOptions failed.");
    return PARAM_INVALID;
  }
  if (multibatch::UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] UpdateNameForInputShapeOfOption failed.");
    return PARAM_INVALID;
  }
  if (multibatch::DeleteIdentityInsertByAdapter(graph) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] DeleteIdentityInsertByAdapter failed.");
    return PARAM_INVALID;
  }
  // batch_shapes_ gets one shape vector per dynamic gear; when the options
  // are absent there is nothing to clone and the pass is a no-op.
  if (!multibatch::InitDynamicParams(batch_shapes_)) {
    GELOGD("There is no multi-batch options, no need clone multi-batch graph");
    return SUCCESS;
  }
  if (multibatch::CheckNegativeCountOfOptions(batch_shapes_) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] Input_shape and dynamic_dims should set correct params.");
    return PARAM_INVALID;
  }
  GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str());
  GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param");
  if (CollectIoNodes(graph) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Collect input output nodes failed");
    return INTERNAL_ERROR;
  }
  // parser data dynamic info from atc parameter --input_shape
  if (multibatch::ParserDataToDynmaicInfo(batch_shapes_, GetLocalOmgContext().user_input_dims,
                                          data_to_dynamic_info_) != SUCCESS) {
    GELOGE(PARAM_INVALID, "Parse each data's own dynamic info failed");
    return PARAM_INVALID;
  }
  (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  // Swap the original content into `branch`; `graph` is then rebuilt as the
  // root (Case) graph and `branch` serves as the template for the subgraphs.
  ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName());
  GE_IF_BOOL_EXEC(branch == nullptr, GELOGE(OUT_OF_MEMORY, "Create multi batch graph failed"); return OUT_OF_MEMORY);
  (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  graph->InValid();  // Will modify, need topological again.
  graph->Swap(*branch);
  GE_CHK_STATUS_RET(CreateRootGraph(graph), "Construct root graph failed.");
  GE_CHK_STATUS_RET(CreateOriGraph(branch), "Construct original graph failed.")
  GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed.");
  GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed");
  GELOGD("MultiBatchClonePass Leave");
  return SUCCESS;
}
  106. ///
  107. /// @ingroup ge
  108. /// @brief Collect input output node from original graph.
  109. /// @param [in] const ComputeGraphPtr &graph: original graph.
  110. /// @return 0: SUCCESS / others: FAILED
  111. ///
  112. Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) {
  113. for (const auto &node : graph->GetDirectNode()) {
  114. if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) {
  115. all_data_nodes_.emplace_back(node);
  116. GE_CHK_STATUS_RET(InitParamsOfGetNext(node), "Init params of %s failed.", node->GetName().c_str());
  117. }
  118. if (node->GetType() == DATA) {
  119. all_data_nodes_.emplace_back(node);
  120. } else if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) {
  121. all_const_nodes_.emplace_back(node);
  122. } else if (node->GetType() == NETOUTPUT) {
  123. all_output_nodes_.emplace_back(node);
  124. }
  125. // If the node save as input/output node, delete record.
  126. (void)graph->RemoveInputNode(node);
  127. (void)graph->RemoveOutputNode(node);
  128. }
  129. if (all_data_nodes_.empty() || all_output_nodes_.size() != 1) {
  130. GELOGE(FAILED, "data nodes: %zu, output nodes: %zu", all_data_nodes_.size(), all_output_nodes_.size());
  131. return FAILED;
  132. }
  133. int64_t data_index = 0;
  134. size_t getnext_node_count = 0;
  135. for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
  136. if (IsGetNextType(all_data_nodes_[i])) {
  137. // just one getnext node in graph
  138. getnext_node_count++;
  139. continue;
  140. }
  141. const auto &op_desc = all_data_nodes_[i]->GetOpDesc();
  142. if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) {
  143. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i - getnext_node_count);
  144. }
  145. }
  146. const auto &output = all_output_nodes_[0];
  147. for (size_t i = 0; i < output->GetAllInDataAnchorsSize(); ++i) {
  148. const auto in_anchor = output->GetInDataAnchor(i);
  149. const auto out_anchor = in_anchor->GetPeerOutAnchor();
  150. const auto data_node = out_anchor->GetOwnerNode();
  151. if (data_node->GetType() == DATA) {
  152. direct_output_[i] = data_node->GetName();
  153. GE_CHK_GRAPH_STATUS_RET(
  154. GraphUtils::RemoveEdge(data_node->GetOutDataAnchor(kDataOutIndex), output->GetInDataAnchor(i)),
  155. "Remove edge failed");
  156. }
  157. }
  158. GELOGD("Data count is %zu, const count is %zu, getnext count is %zu, output count is %zu, direct out count is %zu.",
  159. all_data_nodes_.size(), all_const_nodes_.size(), getnext_node_count, all_output_nodes_.size(),
  160. direct_output_.size());
  161. return SUCCESS;
  162. }
  163. Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) {
  164. data_count_from_getnext_ = 0;
  165. getnext_sink_dynamic_dims_ = false;
  166. GE_CHECK_NOTNULL(node->GetOpDesc());
  167. data_count_from_getnext_ = node->GetOpDesc()->GetOutputsSize();
  168. if (GetLocalOmgContext().dynamic_node_type == GETNEXT) {
  169. data_count_from_getnext_ = data_count_from_getnext_ / kDivisionConst;
  170. for (size_t i = 0; i < data_count_from_getnext_; ++i) {
  171. GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(i);
  172. GELOGD("The %zu data shape from getnext sink is %s.", i,
  173. formats::JoinToString(output_desc.GetShape().GetDims()).c_str());
  174. const auto &dims = output_desc.GetShape().GetDims();
  175. if (std::all_of(dims.begin(), dims.end(), [](int64_t val) {return val >= 0; })) {
  176. GELOGD("The %zu data from %s is static.", i, node->GetName().c_str());
  177. } else {
  178. getnext_sink_dynamic_dims_ = true;
  179. GELOGD("Dynamic dims in the pattern of getnext sink.");
  180. }
  181. }
  182. }
  183. if (node->GetOutControlAnchor() != nullptr) {
  184. for (const auto &peer_in_control_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) {
  185. NodePtr next_node = peer_in_control_anchor->GetOwnerNode();
  186. GE_CHECK_NOTNULL(next_node);
  187. if (next_node->GetType() == CONSTANTOP) {
  188. out_control_nodes_.insert(next_node);
  189. GELOGD("Control edge: %s connect with %s.", node->GetName().c_str(), next_node->GetName().c_str());
  190. }
  191. }
  192. }
  193. return SUCCESS;
  194. }
///
/// @ingroup ge
/// @brief Create nodes for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) {
  GELOGD("Start create root graph of %s.", graph->GetName().c_str());
  // One Case input per Data/Const node; a GetNext node contributes one input
  // per data output instead of one for the node itself.
  uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size();
  if (data_count_from_getnext_ != 0) {
    input_num = input_num + data_count_from_getnext_ - kNumOfGetnextNode;
  }
  uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize();
  OpDescBuilder op_builder(kMultiBatchCaseNode, CASE);
  // Input 0 is the branch index; the real args follow as dynamic inputs.
  op_builder.AddInput("branch_index").AddDynamicInput("input", input_num).AddDynamicOutput("output", output_num);
  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch case desc failed");
    return OUT_OF_MEMORY;
  }
  op_desc->RegisterSubgraphIrName("branches", kDynamic);
  case_node_ = graph->AddNode(op_desc);
  if (case_node_ == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch case node failed");
    return OUT_OF_MEMORY;
  }
  uint32_t batch_num = static_cast<uint32_t>(batch_shapes_.size());
  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
    GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_NUM failed, Case: %s.", op_desc->GetName().c_str());
    return FAILED;
  }
  // Record the shape of every batch gear as ATTR_NAME_PRED_VALUE_<i>.
  for (uint32_t i = 0; i < batch_num; i++) {
    const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shapes_[i])) {
      GELOGE(FAILED, "Set attr ATTR_NAME_PRED_VALUE failed, Case: %s.", op_desc->GetName().c_str());
      return FAILED;
    }
  }
  // Preserve the user-specified ordering of the dynamic inputs.
  std::vector<std::string> data_name_order;
  for (auto &item : GetLocalOmgContext().user_input_dims) {
    data_name_order.push_back(item.first);
  }
  if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order)) {
    GELOGE(FAILED, "Failed to add user designate shape order attr on case node %s",
           op_desc->GetName().c_str());
    return FAILED;
  }
  if (!AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) {
    GELOGE(INTERNAL_ERROR, "Failed to add insert attr on case node %s", op_desc->GetName().c_str());
    return INTERNAL_ERROR;
  }
  GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed");
  // Build the surrounding index/input/const/output nodes and wire them to Case.
  GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed");
  GE_CHK_STATUS_RET(CreateInputNode(graph), "Create input node failed");
  GE_CHK_STATUS_RET(CreateConstNode(graph), "Create const node failed");
  GE_CHK_STATUS_RET(CreateOutputNode(graph), "Create output node failed");
  return SUCCESS;
}
  253. ///
  254. /// @ingroup ge
  255. /// @brief Create index data node for root graph.
  256. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  257. /// @param [in] NodePtr node: index data node.
  258. /// @return 0: SUCCESS / others: FAILED
  259. ///
  260. Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node) {
  261. const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA);
  262. if (data_desc == nullptr) {
  263. GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
  264. return FAILED;
  265. }
  266. GeTensorDesc data_tensor(GeShape({static_cast<int64_t>(batch_shapes_[0].size())}), FORMAT_ND, DT_INT32);
  267. if (data_desc->AddInputDesc(data_tensor) != GRAPH_SUCCESS) {
  268. GELOGE(FAILED, "Add input desc failed");
  269. return FAILED;
  270. }
  271. if (data_desc->AddOutputDesc(data_tensor) != GRAPH_SUCCESS) {
  272. GELOGE(FAILED, "Add output desc failed");
  273. return FAILED;
  274. }
  275. size_t data_index = all_data_nodes_.size();
  276. data_index = data_count_from_getnext_ != 0 ? data_index - kNumOfGetnextNode : data_index;
  277. (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index);
  278. (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);
  279. shape_node = graph->AddNode(data_desc);
  280. if (shape_node == nullptr) {
  281. GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
  282. return OUT_OF_MEMORY;
  283. }
  284. return SUCCESS;
  285. }
  286. ///
  287. /// @ingroup ge
  288. /// @brief Create index const node for root graph.
  289. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  290. /// @param [in] NodePtr node: index const node.
  291. /// @return 0: SUCCESS / others: FAILED
  292. ///
  293. Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, NodePtr &node) {
  294. const OpDescPtr const_desc = MakeShared<OpDesc>(kMultiBatchConstNode, CONSTANT);
  295. if (const_desc == nullptr) {
  296. GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
  297. return FAILED;
  298. }
  299. int64_t count = batch_shapes_.size() * batch_shapes_[0].size();
  300. std::unique_ptr<int32_t[]> addr(new (std::nothrow) int32_t[count]);
  301. GE_CHECK_NOTNULL(addr);
  302. size_t i = 0;
  303. for (auto &batch_shape : batch_shapes_) {
  304. for (int64_t dim : batch_shape) {
  305. addr[i++] = static_cast<int32_t>(dim);
  306. }
  307. }
  308. GeTensorDesc const_tensor(GeShape({count}), FORMAT_ND, DT_INT32);
  309. GeTensor tensor(const_tensor);
  310. (void)tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(int32_t));
  311. if (!AttrUtils::SetTensor(const_desc, ATTR_NAME_WEIGHTS, tensor)) {
  312. GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", const_desc->GetName().c_str());
  313. return FAILED;
  314. }
  315. if (const_desc->AddOutputDesc(const_tensor) != GRAPH_SUCCESS) {
  316. GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", const_desc->GetName().c_str());
  317. return FAILED;
  318. }
  319. node = graph->AddNode(const_desc);
  320. if (node == nullptr) {
  321. GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
  322. return OUT_OF_MEMORY;
  323. }
  324. return SUCCESS;
  325. }
///
/// @ingroup ge
/// @brief Create index node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) {
  // Data/GetDynamicDims --> MapIndex --> Case
  if (!getnext_sink_dynamic_dims_) {
    GE_CHK_STATUS_RET(CreateIndexDataNode(graph, shape_node_), "Create data node failed");
  } else {
    GE_CHK_STATUS_RET(CreateGetDynamicDimsNode(graph, shape_node_), "Create get dynamic dims node failed");
  }
  NodePtr const_node;
  GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed");
  GELOGD("Shape node name is %s, type is %s, const node name is %s.", shape_node_->GetName().c_str(),
         shape_node_->GetType().c_str(), const_node->GetName().c_str());
  // MapIndex presumably matches the runtime dims (x) against the flattened
  // gear table (data_seq) and outputs the scalar branch index for Case.
  OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex");
  op_builder.AddInput("x", shape_node_->GetOpDesc()->GetOutputDesc(0))
      .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0))
      .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32));
  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    // NOTE(review): logs OUT_OF_MEMORY but returns FAILED — confirm intended.
    GELOGE(OUT_OF_MEMORY, "Create multi-batch index desc failed");
    return FAILED;
  }
  NodePtr index_node = graph->AddNode(op_desc);
  if (index_node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch index node failed");
    return OUT_OF_MEMORY;
  }
  GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(shape_node_), "Failed to add attr for %s.",
                    shape_node_->GetName().c_str());
  // Wire: shape_node -> MapIndex.x, const -> MapIndex.data_seq,
  //       MapIndex.y -> Case input 0 (the branch index).
  if (GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", shape_node_->GetName().c_str(),
           index_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(1)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", const_node->GetName().c_str(),
           index_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(index_node->GetOutDataAnchor(0), case_node_->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between MapIndex:%s to Case:%s", index_node->GetName().c_str(),
           case_node_->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
///
/// @ingroup ge
/// @brief Create a GetDynamicDims node for the root graph (used instead of
///        the index Data node when the getnext-sink pattern has dynamic dims).
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [out] shape_node: created GetDynamicDims node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node) {
  const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchGetDynamicDimsNode, GETDYNAMICDIMS);
  if (data_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch get dynamic dims node failed");
    return OUT_OF_MEMORY;
  }
  // input of GetDynamicDims is shape_of_each_data, output is gear_info
  for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) {
    size_t input_shape_dims = GetLocalOmgContext().user_input_dims.at(i).second.size();
    // add input desc without GeShape for const input, value of input_shape is 1 transferred by adapter
    if (input_shape_dims == 1 && GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) {
      GeTensorDesc tensor_desc;
      tensor_desc.SetFormat(FORMAT_ND);
      tensor_desc.SetDataType(DT_INT32);
      auto ret = data_desc->AddInputDesc(tensor_desc);
      GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
                      return FAILED);
      continue;
    }
    // Each remaining input is a 1-D int32 tensor sized by the rank of the
    // corresponding user input (it carries that input's shape at runtime).
    GeTensorDesc tensor_desc(GeShape({static_cast<int32_t>(input_shape_dims)}), FORMAT_ND, DT_INT32);
    auto ret = data_desc->AddInputDesc(tensor_desc);
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
                    return FAILED);
  }
  // Single output: one int32 per dynamic dim of a gear.
  GeTensorDesc tensor_desc(GeShape({static_cast<int32_t>(batch_shapes_.at(0).size())}), FORMAT_ND, DT_INT32);
  auto ret = data_desc->AddOutputDesc(tensor_desc);
  GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data");
                  return FAILED);
  (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);
  shape_node = graph->AddNode(data_desc);
  if (shape_node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch dynamic dims node failed");
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}
  412. Status MultiBatchClonePass::AddAttrForGetDynamicDims(const NodePtr &shape_node) {
  413. if (!getnext_sink_dynamic_dims_) {
  414. GELOGD("No need to add attr when not insert get dynamic dims node.");
  415. return SUCCESS;
  416. }
  417. GELOGD("Add attr for :%s, type is %s:", shape_node->GetName().c_str(), shape_node->GetType().c_str());
  418. if (!AttrUtils::SetInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count_from_getnext_)) {
  419. GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_DATA_COUNT failed");
  420. return INTERNAL_ERROR;
  421. }
  422. vector<int64_t> shape_info;
  423. for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) {
  424. if (GetLocalOmgContext().user_input_dims.at(i).second.size() == 1 &&
  425. GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) {
  426. shape_info.emplace_back(0);
  427. continue;
  428. }
  429. shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.size());
  430. for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) {
  431. shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.at(j));
  432. }
  433. }
  434. if (!AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) {
  435. GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_SHAPE_INFO failed");
  436. return INTERNAL_ERROR;
  437. }
  438. return SUCCESS;
  439. }
  440. Status MultiBatchClonePass::LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node) {
  441. GELOGD("Start relink shape anchor of %s to %s.", getnext_node->GetName().c_str(), shape_node->GetName().c_str());
  442. size_t input_index = 0;
  443. size_t data_count = getnext_node->GetAllOutDataAnchors().size() / kDivisionConst;
  444. for (size_t out_index = data_count; out_index < getnext_node->GetAllOutDataAnchors().size(); ++out_index,
  445. ++input_index) {
  446. GELOGD("Start add %s of %zu out_anchor to %s of %zu in_anchor.", getnext_node->GetName().c_str(), out_index,
  447. shape_node->GetName().c_str(), input_index);
  448. auto out_data_anchor = getnext_node->GetOutDataAnchor(out_index);
  449. auto ret = GraphUtils::AddEdge(out_data_anchor, shape_node->GetInDataAnchor(input_index));
  450. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getnext %s to getdynamicdims %s",
  451. getnext_node->GetName().c_str(), shape_node->GetName().c_str());
  452. return INTERNAL_ERROR);
  453. }
  454. return SUCCESS;
  455. }
///
/// @brief Attach the gears-info attribute to NetOutput and, in the
///        getnext-sink dynamic case, append a NetOutput input fed by the
///        GetDynamicDims node.
/// @param [in] output_node: the NetOutput node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::LinkGetDynamicDimsToNetOutput(const NodePtr &output_node) {
  if (!GetLocalOmgContext().dynamic_node_type.empty()) {
    if (!AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, GetLocalOmgContext().dynamic_dims)) {
      GELOGE(INTERNAL_ERROR, "Failed to set all gears info attr on netoutput %s.", output_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }
  if (getnext_sink_dynamic_dims_) {
    GELOGD("Start link %s to %s.", shape_node_->GetName().c_str(), output_node->GetName().c_str());
    // Grow NetOutput by one input anchor to receive the dynamic-dims tensor.
    size_t input_index = output_node->GetAllInDataAnchors().size();
    if (NodeUtils::AppendInputAnchor(output_node, input_index + 1) != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Append input anchor of %s of %zu failed.", output_node->GetName().c_str(), input_index);
      return INTERNAL_ERROR;
    }
    auto ret = GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(kDataOutIndex),
                                   output_node->GetInDataAnchor(input_index));
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link netoutput %s to getdynamicdims %s",
                    output_node->GetName().c_str(), shape_node_->GetName().c_str());
                    return INTERNAL_ERROR);
    if (!AttrUtils::SetBool(output_node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) {
      GELOGE(INTERNAL_ERROR, "Failed to set getnext sink dynamic attr on netoutput %s.",
             output_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }
  return SUCCESS;
}
  483. ///
  484. /// @ingroup ge
  485. /// @brief Create input node for root graph.
  486. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  487. /// @return 0: SUCCESS / others: FAILED
  488. ///
  489. Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
  490. // Data --> Case
  491. std::vector<NodePtr> all_data_nodes;
  492. size_t case_input_index = kCaseArgIndex;
  493. NodePtr getnext_node = nullptr;
  494. size_t input_index_of_getnext = 0;
  495. for (size_t i = 0; i < all_data_nodes_.size(); ++i, ++case_input_index) {
  496. const auto &node = all_data_nodes_[i];
  497. const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
  498. if (op_desc == nullptr) {
  499. GELOGE(OUT_OF_MEMORY, "Create multi-batch Data node failed, name: %s", node->GetName().c_str());
  500. return FAILED;
  501. }
  502. if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
  503. return FAILED;
  504. }
  505. op_desc->SetName(node->GetName());
  506. const NodePtr &data = graph->AddNode(op_desc);
  507. GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
  508. if (IsGetNextType(node)) {
  509. getnext_node = data;
  510. input_index_of_getnext = case_input_index;
  511. case_input_index = case_input_index + data_count_from_getnext_;
  512. continue;
  513. } else {
  514. if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(case_input_index)) !=
  515. GRAPH_SUCCESS) {
  516. GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", data->GetName().c_str(),
  517. case_node_->GetName().c_str());
  518. return FAILED;
  519. }
  520. }
  521. if (SetMaxShape(data) != SUCCESS) {
  522. GELOGE(FAILED, "Set max shape of %s failed.", data->GetName().c_str());
  523. return FAILED;
  524. }
  525. all_data_nodes.emplace_back(data);
  526. }
  527. if (getnext_node != nullptr) {
  528. if (LinkEdgeForGetNext(getnext_node, input_index_of_getnext) != SUCCESS) {
  529. GELOGE(FAILED, "Failed to link edge for %s.", getnext_node->GetName().c_str());
  530. return FAILED;
  531. }
  532. if (SetMaxShape(getnext_node) != SUCCESS) {
  533. GELOGE(FAILED, "Set max shape of %s failed.", getnext_node->GetName().c_str());
  534. return FAILED;
  535. }
  536. all_data_nodes.emplace_back(getnext_node);
  537. }
  538. all_data_nodes_.swap(all_data_nodes);
  539. return SUCCESS;
  540. }
  541. Status MultiBatchClonePass::LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index) {
  542. GELOGD("Start link edge for %s, which is the %zu input of %s.", getnext_node->GetName().c_str(),
  543. case_input_index, case_node_->GetName().c_str());
  544. for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++case_input_index) {
  545. if (GraphUtils::AddEdge(getnext_node->GetOutDataAnchor(out_index),
  546. case_node_->GetInDataAnchor(case_input_index)) != GRAPH_SUCCESS) {
  547. GELOGE(FAILED, "Failed to add data edge between %zu Data:%s to %zu Case:%s", out_index,
  548. getnext_node->GetName().c_str(), case_input_index, case_node_->GetName().c_str());
  549. return FAILED;
  550. }
  551. }
  552. if (getnext_sink_dynamic_dims_) {
  553. GE_CHK_STATUS_RET(LinkGetNextToGetDynamicDims(getnext_node, shape_node_), "Failed to add link for %s.",
  554. shape_node_->GetName().c_str());
  555. }
  556. return SUCCESS;
  557. }
///
/// @ingroup ge
/// @brief Create Const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) {
  // Const --> Case
  std::vector<NodePtr> all_const_nodes;
  // Const inputs of Case come after the leading Case args and all Data inputs.
  size_t arg_index = kCaseArgIndex + all_data_nodes_.size();
  if (data_count_from_getnext_ != 0) {
    // A single GetNext entry in all_data_nodes_ actually occupies
    // data_count_from_getnext_ Case inputs, so widen the offset accordingly.
    arg_index = arg_index + data_count_from_getnext_ - kNumOfGetnextNode;
  }
  for (size_t i = 0; i < all_const_nodes_.size(); ++i) {
    const auto &node = all_const_nodes_[i];
    // Copy the original Const op desc into the root graph, keeping its name.
    const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
    if (op_desc == nullptr) {
      GELOGE(OUT_OF_MEMORY, "Create multi-batch Const node failed, name: %s", node->GetName().c_str());
      return FAILED;
    }
    op_desc->SetName(node->GetName());
    if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
      return FAILED;
    }
    const NodePtr &data = graph->AddNode(op_desc);
    GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
    // The i-th Const feeds the (arg_index + i)-th input of the Case node.
    if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", data->GetName().c_str(),
             case_node_->GetName().c_str());
      return FAILED;
    }
    all_const_nodes.emplace_back(data);
  }
  // Turn the subgraph-side Consts (still referenced by all_const_nodes_) into
  // Data nodes before swapping in the newly created root-graph Consts.
  ChangeConstToData();
  all_const_nodes_.swap(all_const_nodes);
  return SUCCESS;
}
  595. void MultiBatchClonePass::ChangeConstToData() {
  596. size_t data_index = all_data_nodes_.size();
  597. if (data_count_from_getnext_ != 0) {
  598. data_index = data_index + data_count_from_getnext_ - kNumOfGetnextNode;
  599. }
  600. for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) { // Trans subgraph Const to Data.
  601. auto &const_node = all_const_nodes_[i];
  602. bool need_change_type = true;
  603. if (out_control_nodes_.find(const_node) != out_control_nodes_.end()) {
  604. GELOGD("No need to change %s to data type.", const_node->GetName().c_str());
  605. need_change_type = false;
  606. break;
  607. }
  608. if (!need_change_type) {
  609. continue;
  610. }
  611. const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc();
  612. op_desc->SetType(DATA);
  613. (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS); // Delete weight.
  614. // Const no InputDesc, Data need InputDesc.
  615. (void)op_desc->AddInputDesc(op_desc->GetOutputDesc(kDataOutIndex));
  616. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
  617. (void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1);
  618. }
  619. }
///
/// @ingroup ge
/// @brief Create output node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) {
  // Clone the original NetOutput (stored as all_output_nodes_[0]) into the root graph.
  const auto &output = all_output_nodes_[0];
  const OpDescPtr op_desc = AttrUtils::CopyOpDesc(output->GetOpDesc());
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch output node failed");
    return FAILED;
  }
  if (GraphUtils::CopyTensorAttrs(op_desc, output) != GRAPH_SUCCESS) {
    return FAILED;
  }
  op_desc->SetName(output->GetName());
  const NodePtr &node = graph->AddNode(op_desc);
  GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
  for (size_t i = 0; i < case_node_->GetAllOutDataAnchorsSize(); ++i) {
    const auto it = direct_output_.find(i);
    if (it == direct_output_.end()) {
      // Regular output: produced by the Case node.
      if (GraphUtils::AddEdge(case_node_->GetOutDataAnchor(i), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Failed to add edge between Case:%s to NetOutput:%s",
               case_node_->GetName().c_str(), node->GetName().c_str());
        return FAILED;
      }
    } else {
      // Direct output: a Data node (looked up by the recorded name) feeds
      // NetOutput straight through, bypassing the Case node.
      const auto data_node = graph->FindNode(it->second);
      if (data_node == nullptr) {
        GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Data node:%s not found", it->second.c_str());
        return GE_GRAPH_GRAPH_NODE_NULL;
      }
      if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Failed to add edge between Data:%s to NetOutput:%s",
               data_node->GetName().c_str(), node->GetName().c_str());
        return FAILED;
      }
    }
  }
  GE_CHK_STATUS_RET(LinkGetDynamicDimsToNetOutput(node), "Failed to add edge between %s to netoutput: %s.",
                    shape_node_->GetName().c_str(), output->GetName().c_str());
  // The new root-graph NetOutput replaces the recorded one.
  all_output_nodes_.clear();
  all_output_nodes_.emplace_back(node);
  return SUCCESS;
}
  666. ///
  667. /// @ingroup ge
  668. /// @brief Set max shape to Data node in root graph.
  669. /// @param [in] const NodePtr &data: data in Root/Case graph.
  670. /// @return 0: SUCCESS / others: FAILED
  671. ///
  672. Status MultiBatchClonePass::SetMaxShape(const NodePtr &data) {
  673. GELOGD("Start set max shape for %s.", data->GetName().c_str());
  674. if (!IsGetNextType(data)) {
  675. if (SetMaxShapeToData(data, kDataOutIndex) != SUCCESS) {
  676. GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str());
  677. return PARAM_INVALID;
  678. }
  679. } else {
  680. for (size_t out_anchor_index = 0; out_anchor_index < data_count_from_getnext_; ++out_anchor_index) {
  681. if (SetMaxShapeToData(data, out_anchor_index) != SUCCESS) {
  682. GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str());
  683. return PARAM_INVALID;
  684. }
  685. }
  686. }
  687. return SUCCESS;
  688. }
  689. Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index) {
  690. GELOGD("Start update max shape of %s, %zu output.", node->GetName().c_str(), out_anchor_index);
  691. auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
  692. string data_name = node->GetName();
  693. if (IsGetNextType(node)) {
  694. data_name.append("_").append(std::to_string(out_anchor_index));
  695. }
  696. GELOGD("Update max shape of %s, shape dims is %s.", data_name.c_str(),
  697. formats::JoinToString(data_shape.GetDims()).c_str());
  698. const auto &dims = data_shape.GetDims();
  699. if (!IsGetNextType(node)) {
  700. if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
  701. GELOGD("No need to do anything for static data.");
  702. return SUCCESS;
  703. }
  704. } else {
  705. if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
  706. if (getnext_sink_dynamic_dims_) {
  707. // need to update shape of Shape_node when getnext node has dynamic data
  708. GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node");
  709. }
  710. return SUCCESS;
  711. }
  712. }
  713. (void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  714. GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex));
  715. std::vector<std::string> input_dims_str;
  716. for (size_t i = 0; i < batch_shapes_.size(); ++i) {
  717. auto shape = data_shape;
  718. auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
  719. if (ret != SUCCESS) {
  720. GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", node->GetName().c_str());
  721. return ret;
  722. }
  723. tensor.SetShape(shape);
  724. int64_t tensor_size = 0;
  725. (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
  726. string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
  727. TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" +
  728. std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
  729. formats::JoinToString(tensor.GetShape().GetDims());
  730. input_dims_str.emplace_back(input_str);
  731. }
  732. (void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);
  733. size_t max_shape_index = 0;
  734. int64_t max_size = 0;
  735. for (size_t i = 0; i < batch_shapes_.size(); ++i) {
  736. int64_t size = 1;
  737. for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
  738. if (INT64_MAX / dim < size) {
  739. GELOGE(PARAM_INVALID, "The shape %s size overflow",
  740. formats::ShapeToString(data_to_dynamic_info_.at(data_name).at(i)).c_str());
  741. return PARAM_INVALID;
  742. }
  743. size *= dim;
  744. }
  745. if (size > max_size) {
  746. max_size = size;
  747. max_shape_index = i;
  748. }
  749. }
  750. return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), node, data_shape, out_anchor_index);
  751. }
///
/// @ingroup ge
/// @brief Set max shape to Data/GetNext node in root graph.
/// @param [in] const std::vector<int64_t> &shapes: dims of shape.
/// @param [in] const NodePtr &data: data in Root/Case graph.
/// @param [in] GeShape &data_shape: dims of data node.
/// @param [in] size_t out_anchor_index: out anchor index of data node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape,
                                           size_t out_anchor_index) {
  GELOGD("Start set shape to %zu out of %s.", out_anchor_index, data->GetName().c_str());
  // Resolve data_shape against `shapes`; data_shape is modified in place and
  // then written back to the node's descs below.
  if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to calculate the batched shape for data node %s, the shapes may not match",
           data->GetName().c_str());
    return INTERNAL_ERROR;
  }
  if (NodeUtils::UpdateOutputShape(*data, out_anchor_index, data_shape) != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str());
    return INTERNAL_ERROR;
  }
  if (!IsGetNextType(data)) {
    // Data nodes mirror the shape on their single input desc as well.
    if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str());
      return INTERNAL_ERROR;
    }
  } else {
    if (getnext_sink_dynamic_dims_) {
      // need to update shape of Shape_node when getnext_sink_dynamic
      GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(data, out_anchor_index), "Failed to update shape of shape node");
    }
  }
  GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(),
         formats::ShapeToString(data_shape).c_str());
  return SUCCESS;
}
  788. Status MultiBatchClonePass::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) {
  789. GELOGD("Start update output shape of shape node insert by adapter, which is the %zu out of %s.", out_anchor_index,
  790. node->GetName().c_str());
  791. auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
  792. size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst);
  793. GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index);
  794. std::vector<int64_t> output_dims = {static_cast<int64_t>(data_shape.GetDims().size())};
  795. GeShape output_shape(output_dims);
  796. output_desc.SetShape(output_shape);
  797. if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) {
  798. GELOGE(FAILED, "Update output desc fail.");
  799. return FAILED;
  800. }
  801. return SUCCESS;
  802. }
///
/// @ingroup ge
/// @brief Update Data node in Subgraph.
/// @param [in] const NodePtr &data: data in Subgraph.
/// @param [in] size_t batch_index: The batch index.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t batch_index) {
  int node_index = -1;
  if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) {
    GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str());
    return FAILED;
  }
  // Parent (Case) input index is the data index shifted by one — presumably
  // because Case input 0 is reserved (see kCaseArgIndex); TODO confirm.
  int parent_index = node_index + 1;
  if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    GELOGE(FAILED, "Failed to set parent index for node %s", data->GetName().c_str());
    return FAILED;
  }
  auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  const auto &dims = data_shape.GetDims();
  GELOGD("Start update shape of %s , batch index is %zu, dims is %s.", data->GetName().c_str(), batch_index,
         formats::JoinToString(dims).c_str());
  // Static shapes need no per-batch resolution.
  if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
    return SUCCESS;
  }
  (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  // Clone postfix ("_ascend_mbatch_batch_<i>"-style) is stripped to recover the
  // original node name used as the key into data_to_dynamic_info_.
  auto data_name = data->GetName();
  size_t pos = data_name.find(kMultiBatchNodePostfix);
  if (pos == string::npos) {
    GELOGE(FAILED, "Cannot find key string [%s] of multi-batch in name of virtual input node, node name: %s.",
           kMultiBatchNodePostfix.c_str(), data_name.c_str());
    return FAILED;
  }
  auto parent_name = data_name.substr(0, pos);
  return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(batch_index), data, data_shape, kDataOutIndex);
}
  839. Status MultiBatchClonePass::CreateOriGraph(const ComputeGraphPtr &graph) {
  840. if (data_count_from_getnext_ == 0) {
  841. GELOGD("No need to change original graph without getnext node.");
  842. return SUCCESS;
  843. }
  844. GELOGD("Start change original graph: %s when exit getnext node.", graph->GetName().c_str());
  845. size_t data_index = all_data_nodes_.size() - kNumOfGetnextNode;
  846. for (const auto &node : graph->GetDirectNode()) {
  847. if (IsGetNextType(node)) {
  848. for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++data_index) {
  849. auto out_data_anchor = node->GetOutDataAnchor(out_index);
  850. GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue);
  851. NodePtr data_node = CreateDataNode(graph, out_data_anchor, data_index);
  852. GE_IF_BOOL_EXEC(data_node == nullptr, GELOGE(INTERNAL_ERROR, "Create %zu data node failed.",
  853. out_data_anchor->GetIdx()); return INTERNAL_ERROR);
  854. for (auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  855. GE_IF_BOOL_EXEC(in_anchor == nullptr, continue);
  856. NodePtr dst_node = in_anchor->GetOwnerNode();
  857. if (GraphUtils::RemoveEdge(out_data_anchor, in_anchor) != GRAPH_SUCCESS) {
  858. GELOGE(INTERNAL_ERROR, "Failed to remove edge between %s to %s", node->GetName().c_str(),
  859. dst_node->GetName().c_str());
  860. return INTERNAL_ERROR;
  861. }
  862. if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(in_anchor->GetIdx())) !=
  863. GRAPH_SUCCESS) {
  864. GELOGE(INTERNAL_ERROR, "Failed to add edge between %s to %s", data_node->GetName().c_str(),
  865. dst_node->GetName().c_str());
  866. return INTERNAL_ERROR;
  867. }
  868. }
  869. }
  870. if (graph->RemoveNode(node) != GRAPH_SUCCESS) {
  871. GELOGE(GRAPH_FAILED, "Remove node %s failed!", node->GetName().c_str());
  872. return GRAPH_FAILED;
  873. }
  874. break;
  875. }
  876. }
  877. return SUCCESS;
  878. }
  879. NodePtr MultiBatchClonePass::CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor,
  880. size_t data_index) {
  881. size_t out_anchor_index = out_data_anchor->GetIdx();
  882. std::string node_name = out_data_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor_index);
  883. OpDescPtr op_desc = MakeShared<OpDesc>(node_name, DATA);
  884. if (op_desc == nullptr) {
  885. GELOGE(OUT_OF_MEMORY, "Create data node failed.");
  886. return nullptr;
  887. }
  888. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
  889. OpDescPtr getnext_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
  890. if (getnext_op_desc == nullptr) {
  891. GELOGE(OUT_OF_MEMORY, "Op desc of %s is nullptr.", out_data_anchor->GetOwnerNode()->GetName().c_str());
  892. return nullptr;
  893. }
  894. if (op_desc->AddInputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) {
  895. GELOGE(INTERNAL_ERROR, "Add %s input desc failed.", op_desc->GetName().c_str());
  896. return nullptr;
  897. }
  898. if (op_desc->AddOutputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) {
  899. GELOGE(INTERNAL_ERROR, "Add %s output desc failed.", op_desc->GetName().c_str());
  900. return nullptr;
  901. }
  902. NodePtr data_node = graph->AddNode(op_desc);
  903. GELOGD("Success create %s node.", data_node->GetName().c_str());
  904. return data_node;
  905. }
///
/// @ingroup ge
/// @brief Create nodes for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] const ComputeGraphPtr &branch: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) {
  GELOGD("Start create subgraphs for %s.", graph->GetName().c_str());
  const auto &op_desc = case_node_->GetOpDesc();
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    std::vector<NodePtr> input_nodes;
    std::vector<NodePtr> output_nodes;
    const std::string postfix = kMultiBatchNodePostfix + std::to_string(i);
    // Batch 0 reuses the original graph directly (its nodes are renamed at the
    // bottom of this function); other batches get a postfixed clone.
    ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes);
    GE_IF_BOOL_EXEC(subgraph == nullptr, GELOGE(FAILED, "Create multi-batch case node failed"); return FAILED);
    subgraph->SetName("Batch_" + std::to_string(i));
    subgraph->SetParentNode(case_node_);
    subgraph->SetParentGraph(graph);
    graph->AddSubgraph(subgraph->GetName(), subgraph);
    // Cache each branch's NetOutput for later pruning/update passes.
    all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
    GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]),
                      "Update %s failed", all_branch_output_[subgraph]->GetName().c_str());
    // Register the subgraph as branch i of the Case node.
    const string key_name = "branches" + std::to_string(i);
    op_desc->AddSubgraphName(key_name);
    op_desc->SetSubgraphInstanceName(i, subgraph->GetName());
    GELOGD("The %s has %zu input, %zu output.", subgraph->GetName().c_str(), input_nodes.size(), output_nodes.size());
    // input_nodes is only populated by CloneGraph, so this loop is a no-op for batch 0.
    for (const auto &data : input_nodes) {
      GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str());
    }
  }
  // Original graph taken as first subgraph: update node names (and Data nodes) here.
  for (const auto &n : branch->GetDirectNode()) {
    const auto &op_desc = n->GetOpDesc();
    op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0");
    if (n->GetType() == DATA) {
      GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0), "Update %s failed", branch->GetName().c_str());
    }
  }
  return SUCCESS;
}
  947. ///
  948. /// @ingroup ge
  949. /// @brief Update output_node in Subgraph.
  950. /// @param [in] const NodePtr &output_node: output_node in Subgraph.
  951. /// @return 0: SUCCESS / others: FAILED
  952. ///
  953. Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) {
  954. const auto &op_desc = output_node->GetOpDesc();
  955. GE_CHECK_NOTNULL(op_desc);
  956. for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) {
  957. GeTensorDescPtr tensor = op_desc->MutableInputDesc(index);
  958. GE_CHECK_NOTNULL(tensor);
  959. if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) {
  960. GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str());
  961. return FAILED;
  962. }
  963. }
  964. return SUCCESS;
  965. }
  966. ///
  967. /// @ingroup ge
  968. /// @brief Remove subgraph suspend output anchor.
  969. /// @param [in] ComputeGraphPtr &graph: Parent compute graph.
  970. /// @return 0: SUCCESS / others: FAILED
  971. ///
  972. Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) {
  973. GELOGD("Start prune direct output.");
  974. const auto &func_desc = case_node_->GetOpDesc();
  975. uint32_t unused_num = 0;
  976. uint32_t output_num = func_desc->GetOutputsSize();
  977. for (size_t i = 0; i < output_num; ++i) {
  978. bool is_unused_tensor = true;
  979. for (const auto &item : all_branch_output_) {
  980. const auto &netoutput = item.second;
  981. GE_CHECK_NOTNULL(netoutput);
  982. const auto in_anchor = netoutput->GetInDataAnchor(i);
  983. if (in_anchor->GetPeerOutAnchor() != nullptr) {
  984. is_unused_tensor = false;
  985. break;
  986. }
  987. }
  988. if (is_unused_tensor) {
  989. unused_num++;
  990. continue;
  991. }
  992. GE_CHK_STATUS_RET(UpdateOutputTensor(i, unused_num), "Graph:%s Update output failed", graph->GetName().c_str());
  993. }
  994. if (unused_num == 0) {
  995. return SUCCESS;
  996. }
  997. GE_CHK_STATUS_RET(NodeUtils::RemoveOutputAnchor(case_node_, output_num - unused_num), "Remove output failed");
  998. for (const auto &item : all_branch_output_) {
  999. GE_CHK_STATUS_RET(NodeUtils::RemoveInputAnchor(item.second, output_num - unused_num), "Remove input failed");
  1000. }
  1001. return SUCCESS;
  1002. }
///
/// @ingroup ge
/// @brief Update subgraph suspend output tensor.
/// @param [in] parent_index: parent index for check.
/// @param [in] unused_num: total unused tensor.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) {
  if (unused_num == 0) {
    GELOGD("No need to update output tensor.");
    return SUCCESS;
  }
  // Shift this used output left over the `unused_num` pruned slots before it.
  uint32_t update_index = parent_index - unused_num;
  for (const auto &item : all_branch_output_) {
    const auto &node = item.second;
    const auto &new_anchor = node->GetInDataAnchor(update_index);
    const auto &old_anchor = node->GetInDataAnchor(parent_index);
    // NOTE(review): this assumes every branch NetOutput has a peer at
    // parent_index; the caller only guarantees at least one branch does —
    // confirm out_anchor cannot be nullptr here.
    const auto &out_anchor = old_anchor->GetPeerOutAnchor();
    const auto &out_node = out_anchor->GetOwnerNode();
    const auto &op_desc = node->GetOpDesc();
    (void)op_desc->UpdateInputDesc(update_index, op_desc->GetInputDesc(parent_index));
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, new_anchor), "Add edge failed");
    GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
           case_node_->GetName().c_str(), out_node->GetName().c_str(), parent_index, update_index);
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed");
    GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), out_node->GetName().c_str());
  }
  // Move downstream consumers of the Case output from the old anchor to the new one.
  const auto &new_anchor = case_node_->GetOutDataAnchor(update_index);
  const auto &old_anchor = case_node_->GetOutDataAnchor(parent_index);
  for (const auto in_anchor : old_anchor->GetPeerInDataAnchors()) {
    const auto &in_node = in_anchor->GetOwnerNode();
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(old_anchor, in_anchor), "Remove edge failed");
    GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), in_node->GetName().c_str());
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(new_anchor, in_anchor), "Add edge failed");
    GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
           case_node_->GetName().c_str(), in_node->GetName().c_str(), parent_index, update_index);
  }
  return SUCCESS;
}
  1042. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示