You cannot select more than 25 topics. A topic must start with a Chinese character, a letter or a number, may include dashes ('-'), and can be up to 35 characters long.

multi_batch_clone_pass.cc 52 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/multi_batch_clone_pass.h"
  17. #include "common/formats/utils/formats_trans_utils.h"
  18. #include "common/ge/ge_util.h"
  19. #include "graph/common/local_context.h"
  20. #include "graph/preprocess/multi_batch_options.h"
  21. #include "graph/utils/node_utils.h"
  22. #include "graph/utils/op_desc_utils.h"
  23. #include "graph/utils/tensor_utils.h"
  24. #include "graph/utils/type_utils.h"
  25. #include "register/op_registry.h"
  26. #include "graph/common/omg_util.h"
namespace ge {
namespace {
// Anchor indices: Data nodes read/write anchor 0; on the Case node, input 0
// carries the branch index, so real arguments start at input 1.
constexpr uint8_t kDataInIndex = 0;
constexpr uint8_t kDataOutIndex = 0;
constexpr uint8_t kCaseArgIndex = 1;
// GetNext sink outputs come in (data, shape) pairs, so the data count is the
// total output count divided by 2.
const int kDivisionConst = 2;
// At most one GetNext node is expected per graph.
const size_t kNumOfGetnextNode = 1;
// Fixed names for the helper nodes this pass inserts into the root graph.
const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case";
const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data";
const std::string kMultiBatchGetDynamicDimsNode = "ascend_mbatch_get_dynamic_dims_node";
const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
// Suffix appended to node names when cloning per-batch subgraphs.
const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_";
// Original type that identifies a GetNext sink node.
const char *const kGetNextName = "IteratorV2";
} // namespace
  42. inline bool IsGetNextType(const NodePtr &node) {
  43. std::string original_type;
  44. GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS,
  45. GELOGW("Get original type failed."); return false);
  46. return (original_type == kGetNextName);
  47. }
///
/// @brief Entry point: clone the original graph once per batch gear and wire the
///        clones as branches of a Case node in the (reused) root graph.
/// @param [in] graph: original root graph, modified in place.
/// @return SUCCESS / FAILED / PARAM_INVALID / INTERNAL_ERROR / OUT_OF_MEMORY
///
Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
  GE_IF_BOOL_EXEC(graph == nullptr, GELOGE(FAILED, "Original graph is nullptr"); return FAILED);
  // Only the root graph is processed; subgraphs are skipped untouched.
  if (graph->GetParentGraph() != nullptr) {
    GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str());
    return SUCCESS;
  }
  if (!GetLocalOmgContext().need_multi_batch) {
    GELOGI("No need to process_multi for no_train graph.");
    return SUCCESS;
  }
  // Validate and normalize the dynamic-shape options against the graph inputs.
  std::vector<NodePtr> data_nodes;
  std::vector<NodePtr> getnext_nosink_nodes;
  std::vector<NodePtr> getnext_sink_nodes;
  if (multibatch::CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] CheckSequenceOfOptions failed.");
    return PARAM_INVALID;
  }
  if (multibatch::UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] UpdateNameForInputShapeOfOption failed.");
    return PARAM_INVALID;
  }
  if (multibatch::DeleteIdentityInsertByAdapter(graph) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] DeleteIdentityInsertByAdapter failed.");
    return PARAM_INVALID;
  }
  // batch_shapes_ holds one shape vector per dynamic gear; empty means the
  // multi-batch options were not set and there is nothing to clone.
  if (!multibatch::InitDynamicParams(batch_shapes_)) {
    GELOGD("There is no multi-batch options, no need clone multi-batch graph");
    return SUCCESS;
  }
  if (multibatch::CheckNegativeCountOfOptions(batch_shapes_) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] Input_shape and dynamic_dims should set correct params.");
    return PARAM_INVALID;
  }
  GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str());
  GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param");
  if (CollectIoNodes(graph) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Collect input output nodes failed");
    return INTERNAL_ERROR;
  }
  // parser data dynamic info from atc parameter --input_shape
  if (CheckAndParseDynamicData() != SUCCESS) {
    GELOGE(PARAM_INVALID, "Parse each data's own dynamic info failed");
    return PARAM_INVALID;
  }
  (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  // Move the original content into `branch` and rebuild `graph` as the new
  // root: `branch` becomes the template cloned once per gear.
  ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName());
  GE_IF_BOOL_EXEC(branch == nullptr, GELOGE(OUT_OF_MEMORY, "Create multi batch graph failed"); return OUT_OF_MEMORY);
  (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  graph->InValid(); // Will modify, need topological again.
  graph->Swap(*branch);
  GE_CHK_STATUS_RET(CreateRootGraph(graph), "Construct root graph failed.");
  // NOTE(review): the line below has no trailing ';' — it relies on the
  // GE_CHK_STATUS_RET macro expansion tolerating that; confirm before reformatting.
  GE_CHK_STATUS_RET(CreateOriGraph(branch), "Construct original graph failed.")
  GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed.");
  GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed");
  GELOGD("MultiBatchClonePass Leave");
  return SUCCESS;
}
///
/// @ingroup ge
/// @brief Collect input output node from original graph.
/// @param [in] const ComputeGraphPtr &graph: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) {
  for (const auto &node : graph->GetDirectNode()) {
    // A GetNext sink node also counts as an input when dynamic_node_type is set.
    if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) {
      all_data_nodes_.emplace_back(node);
      GE_CHK_STATUS_RET(InitParamsOfGetNext(node), "Init params of %s failed.", node->GetName().c_str());
    }
    if (node->GetType() == DATA) {
      all_data_nodes_.emplace_back(node);
    } else if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) {
      all_const_nodes_.emplace_back(node);
    } else if (node->GetType() == NETOUTPUT) {
      all_output_nodes_.emplace_back(node);
    }
    // If the node save as input/output node, delete record.
    (void)graph->RemoveInputNode(node);
    (void)graph->RemoveOutputNode(node);
  }
  // The pass requires at least one input node and exactly one NetOutput.
  if (all_data_nodes_.empty() || all_output_nodes_.size() != 1) {
    GELOGE(FAILED, "data nodes: %zu, output nodes: %zu", all_data_nodes_.size(), all_output_nodes_.size());
    return FAILED;
  }
  int64_t data_index = 0;
  size_t getnext_node_count = 0;
  for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
    if (IsGetNextType(all_data_nodes_[i])) {
      // just one getnext node in graph
      getnext_node_count++;
      continue;
    }
    const auto &op_desc = all_data_nodes_[i]->GetOpDesc();
    // Data nodes lacking an explicit index get one from their position,
    // excluding any GetNext nodes counted so far.
    if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) {
      (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i - getnext_node_count);
    }
  }
  // Record Data nodes wired directly to NetOutput and cut those edges;
  // direct_output_ maps anchor index -> data name, presumably consumed later
  // by PruneDirectOutput to restore the connections — verify against that pass.
  const auto &output = all_output_nodes_[0];
  for (size_t i = 0; i < output->GetAllInDataAnchorsSize(); ++i) {
    const auto in_anchor = output->GetInDataAnchor(i);
    const auto out_anchor = in_anchor->GetPeerOutAnchor();
    const auto data_node = out_anchor->GetOwnerNode();
    if (data_node->GetType() == DATA) {
      direct_output_[i] = data_node->GetName();
      GE_CHK_GRAPH_STATUS_RET(
          GraphUtils::RemoveEdge(data_node->GetOutDataAnchor(kDataOutIndex), output->GetInDataAnchor(i)),
          "Remove edge failed");
    }
  }
  GELOGD("Data count is %zu, const count is %zu, getnext count is %zu, output count is %zu, direct out count is %zu.",
         all_data_nodes_.size(), all_const_nodes_.size(), getnext_node_count, all_output_nodes_.size(),
         direct_output_.size());
  return SUCCESS;
}
///
/// @brief Validate each unknown-shape Data node against the dynamic options
///        (--dynamic_batch_size / --dynamic_image_size / --dynamic_dims) and
///        build data_to_dynamic_info_ from the collected name/shape pairs.
/// @return SUCCESS / PARAM_INVALID / INTERNAL_ERROR
///
Status MultiBatchClonePass::CheckAndParseDynamicData(){
  size_t unknown_shape_count = 0;
  // Start from the shapes the user declared via --input_shape; unknown-shape
  // data nodes not listed there are appended below.
  auto data_name_and_shape = GetLocalOmgContext().user_input_dims;
  std::vector<std::string> data_name_order;
  for (auto &item : GetLocalOmgContext().user_input_dims) {
    data_name_order.push_back(item.first);
  }
  // In the getnext-sink-dynamic pattern the shapes come from GetNext, so the
  // per-Data validation below is skipped.
  if (!getnext_sink_dynamic_dims_) {
    for (const auto &node : all_data_nodes_) {
      auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex);
      auto data_shape = data_desc.GetShape();
      auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" :
                         data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others";
      auto data_name = node->GetName();
      GELOGI("CheckAndParseDynamicData shape_dims is %s.", formats::JoinToString(data_shape.GetDims()).c_str());
      std::vector<int64_t> data_shape_dims = data_shape.GetDims();
      // Fully non-negative dims mean a static shape: nothing to parse.
      if (std::all_of(data_shape_dims.begin(), data_shape_dims.end(), [](int64_t val) { return val >= 0; })) {
        continue;
      }
      ++unknown_shape_count;
      auto iter = find(data_name_order.begin(), data_name_order.end(), data_name);
      if (iter == data_name_order.end()) {
        // Dynamic data not named in --input_shape: check it against whichever
        // dynamic option is active; --dynamic_dims requires it to be listed.
        if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
          auto ret = multibatch::CheckDynamicBatchShape(data_shape_dims, data_name);
          GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic batch shape of %s.",
                                               data_name.c_str()); return PARAM_INVALID);
        } else if (!GetLocalOmgContext().dynamic_image_size.empty()) {
          auto ret = multibatch::CheckDynamicImageSizeShape(data_shape_dims, data_name, data_format);
          GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic image size shape of %s.",
                                               data_name.c_str()); return PARAM_INVALID);
        } else if (!GetLocalOmgContext().dynamic_dims.empty()) {
          ErrorManager::GetInstance().ATCReportErrMessage("E10001",{"parameter", "reason"},
                                                          {"--input_shape","all dynamic data must be set in --input_shape"});
          GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape",
                 node->GetName().c_str(), data_shape.ToString().c_str());
          return INTERNAL_ERROR;
        }
        data_name_and_shape.emplace_back(data_name, data_shape_dims);
      }
    }
  }
  auto ret = multibatch::ParserDataToDynmaicInfo(batch_shapes_, data_name_and_shape, data_to_dynamic_info_);
  GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info.");
  // Outside the getnext-sink pattern, a dynamic option without any
  // unknown-shape data is a user configuration error.
  if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) {
    ErrorManager::GetInstance().ATCReportErrMessage("E10040");
    GELOGE(PARAM_INVALID,
           "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
    return PARAM_INVALID;
  }
  return SUCCESS;
}
///
/// @brief Record the data-output count of the GetNext sink node and whether any
///        of its data outputs has a dynamic (negative) dim.
/// @param [in] node: GetNext sink node.
/// @return SUCCESS / others: FAILED
///
Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) {
  data_count_from_getnext_ = 0;
  getnext_sink_dynamic_dims_ = false;
  GE_CHECK_NOTNULL(node->GetOpDesc());
  data_count_from_getnext_ = node->GetOpDesc()->GetOutputsSize();
  if (GetLocalOmgContext().dynamic_node_type == GETNEXT) {
    // GetNext sink outputs are (data, shape) pairs: the first half are data
    // tensors, the second half their shapes.
    data_count_from_getnext_ = data_count_from_getnext_ / kDivisionConst;
    for (size_t i = 0; i < data_count_from_getnext_; ++i) {
      GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(i);
      GELOGD("The %zu data shape from getnext sink is %s.", i,
             formats::JoinToString(output_desc.GetShape().GetDims()).c_str());
      const auto &dims = output_desc.GetShape().GetDims();
      // Any negative dim marks this output — and hence the whole pattern — dynamic.
      if (std::all_of(dims.begin(), dims.end(), [](int64_t val) {return val >= 0; })) {
        GELOGD("The %zu data from %s is static.", i, node->GetName().c_str());
      } else {
        getnext_sink_dynamic_dims_ = true;
        GELOGD("Dynamic dims in the pattern of getnext sink.");
      }
    }
  }
  // Remember CONSTANTOP nodes reached via GetNext's out-control edges for later use.
  if (node->GetOutControlAnchor() != nullptr) {
    for (const auto &peer_in_control_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) {
      NodePtr next_node = peer_in_control_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(next_node);
      if (next_node->GetType() == CONSTANTOP) {
        out_control_nodes_.insert(next_node);
        GELOGD("Control edge: %s connect with %s.", node->GetName().c_str(), next_node->GetName().c_str());
      }
    }
  }
  return SUCCESS;
}
///
/// @ingroup ge
/// @brief Create nodes for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) {
  GELOGD("Start create root graph of %s.", graph->GetName().c_str());
  // Case inputs: every Data + Const node; a GetNext node contributes its
  // data_count_from_getnext_ outputs instead of a single input.
  uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size();
  if (data_count_from_getnext_ != 0) {
    input_num = input_num + data_count_from_getnext_ - kNumOfGetnextNode;
  }
  uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize();
  OpDescBuilder op_builder(kMultiBatchCaseNode, CASE);
  // Input 0 of Case is the branch index; the real arguments follow it.
  op_builder.AddInput("branch_index").AddDynamicInput("input", input_num).AddDynamicOutput("output", output_num);
  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch case desc failed");
    return OUT_OF_MEMORY;
  }
  op_desc->RegisterSubgraphIrName("branches", kDynamic);
  case_node_ = graph->AddNode(op_desc);
  if (case_node_ == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch case node failed");
    return OUT_OF_MEMORY;
  }
  uint32_t batch_num = static_cast<uint32_t>(batch_shapes_.size());
  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
    GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_NUM failed, Case: %s.", op_desc->GetName().c_str());
    return FAILED;
  }
  // Store each gear's shape values as attribute ATTR_NAME_PRED_VALUE + "_<i>".
  for (uint32_t i = 0; i < batch_num; i++) {
    const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shapes_[i])) {
      GELOGE(FAILED, "Set attr ATTR_NAME_PRED_VALUE failed, Case: %s.", op_desc->GetName().c_str());
      return FAILED;
    }
  }
  // Preserve the user's --input_shape ordering on the Case node.
  std::vector<std::string> data_name_order;
  for (auto &item : GetLocalOmgContext().user_input_dims) {
    data_name_order.push_back(item.first);
  }
  if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order)) {
    GELOGE(FAILED, "Failed to add user designate shape order attr on case node %s",
           op_desc->GetName().c_str());
    return FAILED;
  }
  if (!AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) {
    GELOGE(INTERNAL_ERROR, "Failed to add insert attr on case node %s", op_desc->GetName().c_str());
    return INTERNAL_ERROR;
  }
  GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed");
  // Build the rest of the root graph: index chain, inputs, consts, output.
  GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed");
  GE_CHK_STATUS_RET(CreateInputNode(graph), "Create input node failed");
  GE_CHK_STATUS_RET(CreateConstNode(graph), "Create const node failed");
  GE_CHK_STATUS_RET(CreateOutputNode(graph), "Create output node failed");
  return SUCCESS;
}
///
/// @ingroup ge
/// @brief Create index data node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [out] NodePtr &shape_node: created index data node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node) {
  const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA);
  if (data_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
    // NOTE(review): returns FAILED here while the AddNode failure below returns
    // OUT_OF_MEMORY — confirm whether callers distinguish the two codes.
    return FAILED;
  }
  // 1-D int32 tensor carrying one value per dim of a gear shape.
  GeTensorDesc data_tensor(GeShape({static_cast<int64_t>(batch_shapes_[0].size())}), FORMAT_ND, DT_INT32);
  if (data_desc->AddInputDesc(data_tensor) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add input desc failed");
    return FAILED;
  }
  if (data_desc->AddOutputDesc(data_tensor) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add output desc failed");
    return FAILED;
  }
  // Place this data node after all user inputs; a GetNext node does not occupy
  // a Data index slot of its own.
  size_t data_index = all_data_nodes_.size();
  data_index = data_count_from_getnext_ != 0 ? data_index - kNumOfGetnextNode : data_index;
  (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index);
  (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);
  shape_node = graph->AddNode(data_desc);
  if (shape_node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}
///
/// @ingroup ge
/// @brief Create index const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [out] NodePtr &node: created index const node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, NodePtr &node) {
  const OpDescPtr const_desc = MakeShared<OpDesc>(kMultiBatchConstNode, CONSTANT);
  if (const_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
    return FAILED;
  }
  // Flatten all gear shapes into one int32 buffer of
  // batch_shapes_.size() * batch_shapes_[0].size() elements.
  // Assumes every gear has the same rank as gear 0 — presumably guaranteed by
  // the upstream CheckDynamicParams call; verify, otherwise addr[] overruns.
  int64_t count = batch_shapes_.size() * batch_shapes_[0].size();
  std::unique_ptr<int32_t[]> addr(new (std::nothrow) int32_t[count]);
  GE_CHECK_NOTNULL(addr);
  size_t i = 0;
  for (auto &batch_shape : batch_shapes_) {
    for (int64_t dim : batch_shape) {
      addr[i++] = static_cast<int32_t>(dim);
    }
  }
  GeTensorDesc const_tensor(GeShape({count}), FORMAT_ND, DT_INT32);
  GeTensor tensor(const_tensor);
  (void)tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(int32_t));
  if (!AttrUtils::SetTensor(const_desc, ATTR_NAME_WEIGHTS, tensor)) {
    GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", const_desc->GetName().c_str());
    return FAILED;
  }
  if (const_desc->AddOutputDesc(const_tensor) != GRAPH_SUCCESS) {
    GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", const_desc->GetName().c_str());
    return FAILED;
  }
  node = graph->AddNode(const_desc);
  if (node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}
///
/// @ingroup ge
/// @brief Create index node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) {
  // Data/GetDynamicDims --> MapIndex --> Case
  // The shape source is a plain Data node, unless the getnext-sink-dynamic
  // pattern is active, in which case a GetDynamicDims node supplies it.
  if (!getnext_sink_dynamic_dims_) {
    GE_CHK_STATUS_RET(CreateIndexDataNode(graph, shape_node_), "Create data node failed");
  } else {
    GE_CHK_STATUS_RET(CreateGetDynamicDimsNode(graph, shape_node_), "Create get dynamic dims node failed");
  }
  NodePtr const_node;
  GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed");
  GELOGD("Shape node name is %s, type is %s, const node name is %s.", shape_node_->GetName().c_str(),
         shape_node_->GetType().c_str(), const_node->GetName().c_str());
  // MapIndex maps the runtime shape ("x") to a gear index by matching it
  // against the flattened gear table ("data_seq").
  OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex");
  op_builder.AddInput("x", shape_node_->GetOpDesc()->GetOutputDesc(0))
      .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0))
      .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32));
  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch index desc failed");
    return FAILED;
  }
  NodePtr index_node = graph->AddNode(op_desc);
  if (index_node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch index node failed");
    return OUT_OF_MEMORY;
  }
  GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(shape_node_), "Failed to add attr for %s.",
                    shape_node_->GetName().c_str());
  if (GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", shape_node_->GetName().c_str(),
           index_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(1)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", const_node->GetName().c_str(),
           index_node->GetName().c_str());
    return FAILED;
  }
  // MapIndex output feeds the Case branch-index input (anchor 0).
  if (GraphUtils::AddEdge(index_node->GetOutDataAnchor(0), case_node_->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between MapIndex:%s to Case:%s", index_node->GetName().c_str(),
           case_node_->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
///
/// @brief Create the GetDynamicDims node whose inputs are the shape tensor of
///        each user input and whose output is the current gear info.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [out] NodePtr &shape_node: created GetDynamicDims node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node) {
  const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchGetDynamicDimsNode, GETDYNAMICDIMS);
  if (data_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch get dynamic dims node failed");
    return OUT_OF_MEMORY;
  }
  // input of GetDynamicDims is shape_of_each_data, output is gear_info
  for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) {
    size_t input_shape_dims = GetLocalOmgContext().user_input_dims.at(i).second.size();
    // add input desc without GeShape for const input, value of input_shape is 1 transferred by adapter
    if (input_shape_dims == 1 && GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) {
      GeTensorDesc tensor_desc;
      tensor_desc.SetFormat(FORMAT_ND);
      tensor_desc.SetDataType(DT_INT32);
      auto ret = data_desc->AddInputDesc(tensor_desc);
      GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
                      return FAILED);
      continue;
    }
    // Regular input: a 1-D int32 tensor holding that input's dim values.
    GeTensorDesc tensor_desc(GeShape({static_cast<int32_t>(input_shape_dims)}), FORMAT_ND, DT_INT32);
    auto ret = data_desc->AddInputDesc(tensor_desc);
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
                    return FAILED);
  }
  // Output: one int32 per dynamic dim of a gear (same rank for all gears).
  GeTensorDesc tensor_desc(GeShape({static_cast<int32_t>(batch_shapes_.at(0).size())}), FORMAT_ND, DT_INT32);
  auto ret = data_desc->AddOutputDesc(tensor_desc);
  GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data");
                  return FAILED);
  (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);
  shape_node = graph->AddNode(data_desc);
  if (shape_node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch dynamic dims node failed");
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}
  462. Status MultiBatchClonePass::AddAttrForGetDynamicDims(const NodePtr &shape_node) {
  463. if (!getnext_sink_dynamic_dims_) {
  464. GELOGD("No need to add attr when not insert get dynamic dims node.");
  465. return SUCCESS;
  466. }
  467. GELOGD("Add attr for :%s, type is %s:", shape_node->GetName().c_str(), shape_node->GetType().c_str());
  468. if (!AttrUtils::SetInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count_from_getnext_)) {
  469. GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_DATA_COUNT failed");
  470. return INTERNAL_ERROR;
  471. }
  472. vector<int64_t> shape_info;
  473. for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) {
  474. if (GetLocalOmgContext().user_input_dims.at(i).second.size() == 1 &&
  475. GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) {
  476. shape_info.emplace_back(0);
  477. continue;
  478. }
  479. shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.size());
  480. for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) {
  481. shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.at(j));
  482. }
  483. }
  484. if (!AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) {
  485. GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_SHAPE_INFO failed");
  486. return INTERNAL_ERROR;
  487. }
  488. return SUCCESS;
  489. }
  490. Status MultiBatchClonePass::LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node) {
  491. GELOGD("Start relink shape anchor of %s to %s.", getnext_node->GetName().c_str(), shape_node->GetName().c_str());
  492. size_t input_index = 0;
  493. size_t data_count = getnext_node->GetAllOutDataAnchors().size() / kDivisionConst;
  494. for (size_t out_index = data_count; out_index < getnext_node->GetAllOutDataAnchors().size(); ++out_index,
  495. ++input_index) {
  496. GELOGD("Start add %s of %zu out_anchor to %s of %zu in_anchor.", getnext_node->GetName().c_str(), out_index,
  497. shape_node->GetName().c_str(), input_index);
  498. auto out_data_anchor = getnext_node->GetOutDataAnchor(out_index);
  499. auto ret = GraphUtils::AddEdge(out_data_anchor, shape_node->GetInDataAnchor(input_index));
  500. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getnext %s to getdynamicdims %s",
  501. getnext_node->GetName().c_str(), shape_node->GetName().c_str());
  502. return INTERNAL_ERROR);
  503. }
  504. return SUCCESS;
  505. }
///
/// @brief Stamp gear info on NetOutput and, in the getnext-sink pattern, append
///        an extra NetOutput input fed by the GetDynamicDims node.
/// @param [in] output_node: the graph's NetOutput node.
/// @return SUCCESS / INTERNAL_ERROR
///
Status MultiBatchClonePass::LinkGetDynamicDimsToNetOutput(const NodePtr &output_node) {
  if (!GetLocalOmgContext().dynamic_node_type.empty()) {
    if (!AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, GetLocalOmgContext().dynamic_dims)) {
      GELOGE(INTERNAL_ERROR, "Failed to set all gears info attr on netoutput %s.", output_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }
  if (getnext_sink_dynamic_dims_) {
    GELOGD("Start link %s to %s.", shape_node_->GetName().c_str(), output_node->GetName().c_str());
    // Grow NetOutput by one input anchor and connect the gear info to it.
    size_t input_index = output_node->GetAllInDataAnchors().size();
    if (NodeUtils::AppendInputAnchor(output_node, input_index + 1) != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Append input anchor of %s of %zu failed.", output_node->GetName().c_str(), input_index);
      return INTERNAL_ERROR;
    }
    auto ret = GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(kDataOutIndex),
                                   output_node->GetInDataAnchor(input_index));
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link netoutput %s to getdynamicdims %s",
                                                 output_node->GetName().c_str(), shape_node_->GetName().c_str());
                    return INTERNAL_ERROR);
    if (!AttrUtils::SetBool(output_node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) {
      GELOGE(INTERNAL_ERROR, "Failed to set getnext sink dynamic attr on netoutput %s.",
             output_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }
  return SUCCESS;
}
///
/// @ingroup ge
/// @brief Create input node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
  // Data --> Case
  std::vector<NodePtr> all_data_nodes;
  size_t case_input_index = kCaseArgIndex;  // Data inputs of the Case node start at kCaseArgIndex.
  NodePtr getnext_node = nullptr;
  size_t input_index_of_getnext = 0;
  for (size_t i = 0; i < all_data_nodes_.size(); ++i, ++case_input_index) {
    const auto &node = all_data_nodes_[i];
    // Copy each original Data/GetNext op (attrs + tensor descs) into the root graph under the same name.
    const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
    if (op_desc == nullptr) {
      GELOGE(OUT_OF_MEMORY, "Create multi-batch Data node failed, name: %s", node->GetName().c_str());
      return FAILED;
    }
    if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
      return FAILED;
    }
    op_desc->SetName(node->GetName());
    const NodePtr &data = graph->AddNode(op_desc);
    GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
    if (IsGetNextType(node)) {
      // GetNext occupies data_count_from_getnext_ consecutive Case inputs; its edges are linked below.
      getnext_node = data;
      input_index_of_getnext = case_input_index;
      // NOTE(review): `continue` still runs the loop increment, so the next Data lands at
      // input_index_of_getnext + data_count_from_getnext_ + 1 — confirm this one-slot skip is
      // intended when GetNext is not the last element of all_data_nodes_.
      case_input_index = case_input_index + data_count_from_getnext_;
      continue;
    } else {
      if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(case_input_index)) !=
          GRAPH_SUCCESS) {
        GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", data->GetName().c_str(),
               case_node_->GetName().c_str());
        return FAILED;
      }
    }
    // Expand any dynamic dims of the Data node to the largest gear so the root graph holds the max shape.
    if (SetMaxShape(data) != SUCCESS) {
      GELOGE(FAILED, "Set max shape of %s failed.", data->GetName().c_str());
      return FAILED;
    }
    all_data_nodes.emplace_back(data);
  }
  if (getnext_node != nullptr) {
    // Link every GetNext output to its Case input slot (and to the shape node in getnext-sink mode).
    if (LinkEdgeForGetNext(getnext_node, input_index_of_getnext) != SUCCESS) {
      GELOGE(FAILED, "Failed to link edge for %s.", getnext_node->GetName().c_str());
      return FAILED;
    }
    if (SetMaxShape(getnext_node) != SUCCESS) {
      GELOGE(FAILED, "Set max shape of %s failed.", getnext_node->GetName().c_str());
      return FAILED;
    }
    all_data_nodes.emplace_back(getnext_node);
  }
  // Replace the tracked originals with the newly created root-graph nodes.
  all_data_nodes_.swap(all_data_nodes);
  return SUCCESS;
}
  591. Status MultiBatchClonePass::LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index) {
  592. GELOGD("Start link edge for %s, which is the %zu input of %s.", getnext_node->GetName().c_str(),
  593. case_input_index, case_node_->GetName().c_str());
  594. for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++case_input_index) {
  595. if (GraphUtils::AddEdge(getnext_node->GetOutDataAnchor(out_index),
  596. case_node_->GetInDataAnchor(case_input_index)) != GRAPH_SUCCESS) {
  597. GELOGE(FAILED, "Failed to add data edge between %zu Data:%s to %zu Case:%s", out_index,
  598. getnext_node->GetName().c_str(), case_input_index, case_node_->GetName().c_str());
  599. return FAILED;
  600. }
  601. }
  602. if (getnext_sink_dynamic_dims_) {
  603. GE_CHK_STATUS_RET(LinkGetNextToGetDynamicDims(getnext_node, shape_node_), "Failed to add link for %s.",
  604. shape_node_->GetName().c_str());
  605. }
  606. return SUCCESS;
  607. }
///
/// @ingroup ge
/// @brief Create Const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) {
  // Const --> Case
  std::vector<NodePtr> all_const_nodes;
  // Const inputs of the Case node come after all data inputs.
  size_t arg_index = kCaseArgIndex + all_data_nodes_.size();
  if (data_count_from_getnext_ != 0) {
    // GetNext is one entry of all_data_nodes_ but feeds data_count_from_getnext_ Case inputs.
    arg_index = arg_index + data_count_from_getnext_ - kNumOfGetnextNode;
  }
  for (size_t i = 0; i < all_const_nodes_.size(); ++i) {
    const auto &node = all_const_nodes_[i];
    // Copy each original Const op (attrs + tensor descs) into the root graph under the same name.
    const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
    if (op_desc == nullptr) {
      GELOGE(OUT_OF_MEMORY, "Create multi-batch Const node failed, name: %s", node->GetName().c_str());
      return FAILED;
    }
    op_desc->SetName(node->GetName());
    if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
      return FAILED;
    }
    const NodePtr &data = graph->AddNode(op_desc);
    GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
    if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", data->GetName().c_str(),
             case_node_->GetName().c_str());
      return FAILED;
    }
    all_const_nodes.emplace_back(data);
  }
  // Turn the tracked subgraph Consts into Data nodes so the branches read them as Case inputs.
  ChangeConstToData();
  all_const_nodes_.swap(all_const_nodes);
  return SUCCESS;
}
  645. void MultiBatchClonePass::ChangeConstToData() {
  646. size_t data_index = all_data_nodes_.size();
  647. if (data_count_from_getnext_ != 0) {
  648. data_index = data_index + data_count_from_getnext_ - kNumOfGetnextNode;
  649. }
  650. for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) { // Trans subgraph Const to Data.
  651. auto &const_node = all_const_nodes_[i];
  652. bool need_change_type = true;
  653. if (out_control_nodes_.find(const_node) != out_control_nodes_.end()) {
  654. GELOGD("No need to change %s to data type.", const_node->GetName().c_str());
  655. need_change_type = false;
  656. break;
  657. }
  658. if (!need_change_type) {
  659. continue;
  660. }
  661. const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc();
  662. op_desc->SetType(DATA);
  663. (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS); // Delete weight.
  664. // Const no InputDesc, Data need InputDesc.
  665. (void)op_desc->AddInputDesc(op_desc->GetOutputDesc(kDataOutIndex));
  666. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
  667. (void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1);
  668. }
  669. }
///
/// @ingroup ge
/// @brief Create output node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) {
  const auto &output = all_output_nodes_[0];
  // Clone the original NetOutput (attrs + tensor descs) into the root graph under the same name.
  const OpDescPtr op_desc = AttrUtils::CopyOpDesc(output->GetOpDesc());
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch output node failed");
    return FAILED;
  }
  if (GraphUtils::CopyTensorAttrs(op_desc, output) != GRAPH_SUCCESS) {
    return FAILED;
  }
  op_desc->SetName(output->GetName());
  const NodePtr &node = graph->AddNode(op_desc);
  GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
  for (size_t i = 0; i < case_node_->GetAllOutDataAnchorsSize(); ++i) {
    const auto it = direct_output_.find(i);
    if (it == direct_output_.end()) {
      // Regular output: produced by the Case node.
      if (GraphUtils::AddEdge(case_node_->GetOutDataAnchor(i), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Failed to add edge between Case:%s to NetOutput:%s",
               case_node_->GetName().c_str(), node->GetName().c_str());
        return FAILED;
      }
    } else {
      // Direct output: fed straight from the recorded Data node, bypassing the Case node.
      const auto data_node = graph->FindNode(it->second);
      if (data_node == nullptr) {
        GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Data node:%s not found", it->second.c_str());
        return GE_GRAPH_GRAPH_NODE_NULL;
      }
      if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Failed to add edge between Data:%s to NetOutput:%s",
               data_node->GetName().c_str(), node->GetName().c_str());
        return FAILED;
      }
    }
  }
  // Attach gear info / the dynamic-dims input to the new NetOutput when needed.
  GE_CHK_STATUS_RET(LinkGetDynamicDimsToNetOutput(node), "Failed to add edge between %s to netoutput: %s.",
                    shape_node_->GetName().c_str(), output->GetName().c_str());
  all_output_nodes_.clear();
  all_output_nodes_.emplace_back(node);
  return SUCCESS;
}
  716. ///
  717. /// @ingroup ge
  718. /// @brief Set max shape to Data node in root graph.
  719. /// @param [in] const NodePtr &data: data in Root/Case graph.
  720. /// @return 0: SUCCESS / others: FAILED
  721. ///
  722. Status MultiBatchClonePass::SetMaxShape(const NodePtr &data) {
  723. GELOGD("Start set max shape for %s.", data->GetName().c_str());
  724. if (!IsGetNextType(data)) {
  725. if (SetMaxShapeToData(data, kDataOutIndex) != SUCCESS) {
  726. GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str());
  727. return PARAM_INVALID;
  728. }
  729. } else {
  730. for (size_t out_anchor_index = 0; out_anchor_index < data_count_from_getnext_; ++out_anchor_index) {
  731. if (SetMaxShapeToData(data, out_anchor_index) != SUCCESS) {
  732. GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str());
  733. return PARAM_INVALID;
  734. }
  735. }
  736. }
  737. return SUCCESS;
  738. }
  739. Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index) {
  740. GELOGD("Start update max shape of %s, %zu output.", node->GetName().c_str(), out_anchor_index);
  741. auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
  742. string data_name = node->GetName();
  743. if (IsGetNextType(node)) {
  744. data_name.append("_").append(std::to_string(out_anchor_index));
  745. }
  746. GELOGD("Update max shape of %s, shape dims is %s.", data_name.c_str(),
  747. formats::JoinToString(data_shape.GetDims()).c_str());
  748. const auto &dims = data_shape.GetDims();
  749. if (!IsGetNextType(node)) {
  750. if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
  751. GELOGD("No need to do anything for static data.");
  752. return SUCCESS;
  753. }
  754. } else {
  755. if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
  756. if (getnext_sink_dynamic_dims_) {
  757. // need to update shape of Shape_node when getnext node has dynamic data
  758. GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node");
  759. }
  760. return SUCCESS;
  761. }
  762. }
  763. (void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  764. GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex));
  765. std::vector<std::string> input_dims_str;
  766. for (size_t i = 0; i < batch_shapes_.size(); ++i) {
  767. auto shape = data_shape;
  768. auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
  769. if (ret != SUCCESS) {
  770. GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", node->GetName().c_str());
  771. return ret;
  772. }
  773. tensor.SetShape(shape);
  774. int64_t tensor_size = 0;
  775. (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
  776. string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
  777. TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" +
  778. std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
  779. formats::JoinToString(tensor.GetShape().GetDims());
  780. input_dims_str.emplace_back(input_str);
  781. }
  782. (void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);
  783. size_t max_shape_index = 0;
  784. int64_t max_size = 0;
  785. for (size_t i = 0; i < batch_shapes_.size(); ++i) {
  786. int64_t size = 1;
  787. for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
  788. if (INT64_MAX / dim < size) {
  789. GELOGE(PARAM_INVALID, "The shape %s size overflow",
  790. formats::ShapeToString(data_to_dynamic_info_.at(data_name).at(i)).c_str());
  791. return PARAM_INVALID;
  792. }
  793. size *= dim;
  794. }
  795. if (size > max_size) {
  796. max_size = size;
  797. max_shape_index = i;
  798. }
  799. }
  800. return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), node, data_shape, out_anchor_index);
  801. }
///
/// @ingroup ge
/// @brief Set max shape to Data/GetNext node in root graph.
/// @param [in] const std::vector<int64_t> &shapes: dims of shape.
/// @param [in] const NodePtr &data: data in Root/Case graph.
/// @param [in,out] GeShape &data_shape: dims of data node; updated in place with the resolved dims.
/// @param [in] size_t out_anchor_index: out anchor index of data node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape,
                                           size_t out_anchor_index) {
  GELOGD("Start set shape to %zu out of %s.", out_anchor_index, data->GetName().c_str());
  // Resolve data_shape against the gear values in `shapes` (modifies data_shape in place).
  if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to calculate the batched shape for data node %s, the shapes may not match",
           data->GetName().c_str());
    return INTERNAL_ERROR;
  }
  if (NodeUtils::UpdateOutputShape(*data, out_anchor_index, data_shape) != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str());
    return INTERNAL_ERROR;
  }
  if (!IsGetNextType(data)) {
    // A Data node mirrors the resolved shape on its input as well.
    if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str());
      return INTERNAL_ERROR;
    }
  } else {
    if (getnext_sink_dynamic_dims_) {
      // need to update shape of Shape_node when getnext_sink_dynamic
      GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(data, out_anchor_index), "Failed to update shape of shape node");
    }
  }
  GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(),
         formats::ShapeToString(data_shape).c_str());
  return SUCCESS;
}
  838. Status MultiBatchClonePass::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) {
  839. GELOGD("Start update output shape of shape node insert by adapter, which is the %zu out of %s.", out_anchor_index,
  840. node->GetName().c_str());
  841. auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
  842. size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst);
  843. GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index);
  844. std::vector<int64_t> output_dims = {static_cast<int64_t>(data_shape.GetDims().size())};
  845. GeShape output_shape(output_dims);
  846. output_desc.SetShape(output_shape);
  847. if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) {
  848. GELOGE(FAILED, "Update output desc fail.");
  849. return FAILED;
  850. }
  851. return SUCCESS;
  852. }
///
/// @ingroup ge
/// @brief Update Data node in Subgraph.
/// @param [in] const NodePtr &data: data in Subgraph.
/// @param [in] size_t batch_index: The batch index.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t batch_index) {
  int node_index = -1;
  if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) {
    GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str());
    return FAILED;
  }
  // +1: the parent (Case) data inputs start one slot after ATTR_NAME_INDEX numbering —
  // presumably matching the kCaseArgIndex offset; confirm against CreateInputNode.
  int parent_index = node_index + 1;
  if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    GELOGE(FAILED, "Failed to set parent index for node %s", data->GetName().c_str());
    return FAILED;
  }
  auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  const auto &dims = data_shape.GetDims();
  GELOGD("Start update shape of %s , batch index is %zu, dims is %s.", data->GetName().c_str(), batch_index,
         formats::JoinToString(dims).c_str());
  if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
    // Fully static shape: nothing to resolve for this gear.
    return SUCCESS;
  }
  (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  // Strip the multi-batch postfix to recover the parent node name used as the key
  // of data_to_dynamic_info_.
  auto data_name = data->GetName();
  size_t pos = data_name.find(kMultiBatchNodePostfix);
  if (pos == string::npos) {
    GELOGE(FAILED, "Cannot find key string [%s] of multi-batch in name of virtual input node, node name: %s.",
           kMultiBatchNodePostfix.c_str(), data_name.c_str());
    return FAILED;
  }
  auto parent_name = data_name.substr(0, pos);
  // NOTE(review): .at() throws if parent_name or batch_index is absent — callers appear to
  // guarantee both; confirm before relying on it.
  return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(batch_index), data, data_shape, kDataOutIndex);
}
  889. Status MultiBatchClonePass::CreateOriGraph(const ComputeGraphPtr &graph) {
  890. if (data_count_from_getnext_ == 0) {
  891. GELOGD("No need to change original graph without getnext node.");
  892. return SUCCESS;
  893. }
  894. GELOGD("Start change original graph: %s when exit getnext node.", graph->GetName().c_str());
  895. size_t data_index = all_data_nodes_.size() - kNumOfGetnextNode;
  896. for (const auto &node : graph->GetDirectNode()) {
  897. if (IsGetNextType(node)) {
  898. for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++data_index) {
  899. auto out_data_anchor = node->GetOutDataAnchor(out_index);
  900. GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue);
  901. NodePtr data_node = CreateDataNode(graph, out_data_anchor, data_index);
  902. GE_IF_BOOL_EXEC(data_node == nullptr, GELOGE(INTERNAL_ERROR, "Create %zu data node failed.",
  903. out_data_anchor->GetIdx()); return INTERNAL_ERROR);
  904. for (auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  905. GE_IF_BOOL_EXEC(in_anchor == nullptr, continue);
  906. NodePtr dst_node = in_anchor->GetOwnerNode();
  907. if (GraphUtils::RemoveEdge(out_data_anchor, in_anchor) != GRAPH_SUCCESS) {
  908. GELOGE(INTERNAL_ERROR, "Failed to remove edge between %s to %s", node->GetName().c_str(),
  909. dst_node->GetName().c_str());
  910. return INTERNAL_ERROR;
  911. }
  912. if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(in_anchor->GetIdx())) !=
  913. GRAPH_SUCCESS) {
  914. GELOGE(INTERNAL_ERROR, "Failed to add edge between %s to %s", data_node->GetName().c_str(),
  915. dst_node->GetName().c_str());
  916. return INTERNAL_ERROR;
  917. }
  918. }
  919. }
  920. if (graph->RemoveNode(node) != GRAPH_SUCCESS) {
  921. GELOGE(GRAPH_FAILED, "Remove node %s failed!", node->GetName().c_str());
  922. return GRAPH_FAILED;
  923. }
  924. break;
  925. }
  926. }
  927. return SUCCESS;
  928. }
  929. NodePtr MultiBatchClonePass::CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor,
  930. size_t data_index) {
  931. size_t out_anchor_index = out_data_anchor->GetIdx();
  932. std::string node_name = out_data_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor_index);
  933. OpDescPtr op_desc = MakeShared<OpDesc>(node_name, DATA);
  934. if (op_desc == nullptr) {
  935. GELOGE(OUT_OF_MEMORY, "Create data node failed.");
  936. return nullptr;
  937. }
  938. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
  939. OpDescPtr getnext_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
  940. if (getnext_op_desc == nullptr) {
  941. GELOGE(OUT_OF_MEMORY, "Op desc of %s is nullptr.", out_data_anchor->GetOwnerNode()->GetName().c_str());
  942. return nullptr;
  943. }
  944. if (op_desc->AddInputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) {
  945. GELOGE(INTERNAL_ERROR, "Add %s input desc failed.", op_desc->GetName().c_str());
  946. return nullptr;
  947. }
  948. if (op_desc->AddOutputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) {
  949. GELOGE(INTERNAL_ERROR, "Add %s output desc failed.", op_desc->GetName().c_str());
  950. return nullptr;
  951. }
  952. NodePtr data_node = graph->AddNode(op_desc);
  953. GELOGD("Success create %s node.", data_node->GetName().c_str());
  954. return data_node;
  955. }
///
/// @ingroup ge
/// @brief Create nodes for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] const ComputeGraphPtr &branch: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) {
  GELOGD("Start create subgraphs for %s.", graph->GetName().c_str());
  const auto &op_desc = case_node_->GetOpDesc();
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    std::vector<NodePtr> input_nodes;
    std::vector<NodePtr> output_nodes;
    const std::string postfix = kMultiBatchNodePostfix + std::to_string(i);
    // Batch 0 reuses the original graph directly; every other batch gets a suffixed clone.
    ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes);
    GE_IF_BOOL_EXEC(subgraph == nullptr, GELOGE(FAILED, "Create multi-batch case node failed"); return FAILED);
    subgraph->SetName("Batch_" + std::to_string(i));
    subgraph->SetParentNode(case_node_);
    subgraph->SetParentGraph(graph);
    graph->AddSubgraph(subgraph->GetName(), subgraph);
    // NOTE(review): FindFirstNodeMatchType may return nullptr if the branch carries no NetOutput,
    // yet the result is dereferenced below — confirm a NetOutput is always present.
    all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
    GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]),
                      "Update %s failed", all_branch_output_[subgraph]->GetName().c_str());
    const string key_name = "branches" + std::to_string(i);
    op_desc->AddSubgraphName(key_name);
    op_desc->SetSubgraphInstanceName(i, subgraph->GetName());
    GELOGD("The %s has %zu input, %zu output.", subgraph->GetName().c_str(), input_nodes.size(), output_nodes.size());
    // Cloned Data nodes (batches > 0) need their parent-index / gear shape resolved.
    for (const auto &data : input_nodes) {
      GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str());
    }
  }
  // Original graph is taken as the first subgraph: rename its nodes with the batch-0 postfix.
  for (const auto &n : branch->GetDirectNode()) {
    const auto &op_desc = n->GetOpDesc();  // intentionally shadows the Case op_desc above
    op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0");
    if (n->GetType() == DATA) {
      GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0), "Update %s failed", branch->GetName().c_str());
    }
  }
  return SUCCESS;
}
  997. ///
  998. /// @ingroup ge
  999. /// @brief Update output_node in Subgraph.
  1000. /// @param [in] const NodePtr &output_node: output_node in Subgraph.
  1001. /// @return 0: SUCCESS / others: FAILED
  1002. ///
  1003. Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) {
  1004. const auto &op_desc = output_node->GetOpDesc();
  1005. GE_CHECK_NOTNULL(op_desc);
  1006. for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) {
  1007. GeTensorDescPtr tensor = op_desc->MutableInputDesc(index);
  1008. GE_CHECK_NOTNULL(tensor);
  1009. if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) {
  1010. GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str());
  1011. return FAILED;
  1012. }
  1013. }
  1014. return SUCCESS;
  1015. }
  1016. ///
  1017. /// @ingroup ge
  1018. /// @brief Remove subgraph suspend output anchor.
  1019. /// @param [in] ComputeGraphPtr &graph: Parent compute graph.
  1020. /// @return 0: SUCCESS / others: FAILED
  1021. ///
  1022. Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) {
  1023. GELOGD("Start prune direct output.");
  1024. const auto &func_desc = case_node_->GetOpDesc();
  1025. uint32_t unused_num = 0;
  1026. uint32_t output_num = func_desc->GetOutputsSize();
  1027. for (size_t i = 0; i < output_num; ++i) {
  1028. bool is_unused_tensor = true;
  1029. for (const auto &item : all_branch_output_) {
  1030. const auto &netoutput = item.second;
  1031. GE_CHECK_NOTNULL(netoutput);
  1032. const auto in_anchor = netoutput->GetInDataAnchor(i);
  1033. if (in_anchor->GetPeerOutAnchor() != nullptr) {
  1034. is_unused_tensor = false;
  1035. break;
  1036. }
  1037. }
  1038. if (is_unused_tensor) {
  1039. unused_num++;
  1040. continue;
  1041. }
  1042. GE_CHK_STATUS_RET(UpdateOutputTensor(i, unused_num), "Graph:%s Update output failed", graph->GetName().c_str());
  1043. }
  1044. if (unused_num == 0) {
  1045. return SUCCESS;
  1046. }
  1047. GE_CHK_STATUS_RET(NodeUtils::RemoveOutputAnchor(case_node_, output_num - unused_num), "Remove output failed");
  1048. for (const auto &item : all_branch_output_) {
  1049. GE_CHK_STATUS_RET(NodeUtils::RemoveInputAnchor(item.second, output_num - unused_num), "Remove input failed");
  1050. }
  1051. return SUCCESS;
  1052. }
///
/// @ingroup ge
/// @brief Update subgraph suspend output tensor.
/// @param [in] parent_index: parent index for check.
/// @param [in] unused_num: total unused tensor.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) {
  if (unused_num == 0) {
    GELOGD("No need to update output tensor.");
    return SUCCESS;
  }
  // Move output `parent_index` left by `unused_num` slots: rewire each branch NetOutput input,
  // then every consumer of the Case output, from the old anchor to the shifted one.
  uint32_t update_index = parent_index - unused_num;
  for (const auto &item : all_branch_output_) {
    const auto &node = item.second;
    const auto &new_anchor = node->GetInDataAnchor(update_index);
    const auto &old_anchor = node->GetInDataAnchor(parent_index);
    // NOTE(review): the caller (PruneDirectOutput) only invokes this for used outputs, so the
    // old anchor is expected to have a peer — confirm, as it is dereferenced unconditionally.
    const auto &out_anchor = old_anchor->GetPeerOutAnchor();
    const auto &out_node = out_anchor->GetOwnerNode();
    const auto &op_desc = node->GetOpDesc();
    // Carry the tensor desc along with the edge.
    (void)op_desc->UpdateInputDesc(update_index, op_desc->GetInputDesc(parent_index));
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, new_anchor), "Add edge failed");
    GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
           case_node_->GetName().c_str(), out_node->GetName().c_str(), parent_index, update_index);
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed");
    GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), out_node->GetName().c_str());
  }
  // Re-home consumers of the Case node's old output anchor onto the shifted anchor.
  const auto &new_anchor = case_node_->GetOutDataAnchor(update_index);
  const auto &old_anchor = case_node_->GetOutDataAnchor(parent_index);
  for (const auto in_anchor : old_anchor->GetPeerInDataAnchors()) {
    const auto &in_node = in_anchor->GetOwnerNode();
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(old_anchor, in_anchor), "Remove edge failed");
    GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), in_node->GetName().c_str());
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(new_anchor, in_anchor), "Add edge failed");
    GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
           case_node_->GetName().c_str(), in_node->GetName().c_str(), parent_index, update_index);
  }
  return SUCCESS;
}
  1092. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,而用户并不感知。