
multi_batch_copy_graph.cc 52 kB

  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/preprocess/multi_batch_copy_graph.h"
  17. #include <queue>
  18. #include <set>
  19. #include <string>
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "common/ge/ge_util.h"
  22. #include "common/util/error_manager/error_manager.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "framework/common/ge_inner_error_codes.h"
  25. #include "framework/common/string_util.h"
  26. #include "framework/common/types.h"
  27. #include "framework/omg/omg_inner_types.h"
  28. #include "graph/debug/ge_attr_define.h"
  29. #include "graph/ge_context.h"
  30. #include "graph/passes/multi_batch_clone_pass.h"
  31. #include "graph/passes/prune_pass.h"
  32. #include "graph/preprocess/multi_batch_options.h"
  33. #include "graph/utils/attr_utils.h"
  34. #include "graph/utils/graph_utils.h"
  35. #include "graph/utils/node_utils.h"
  36. #include "graph/utils/tensor_utils.h"
  37. #include "graph/utils/type_utils.h"
  38. #include "inc/pass_manager.h"
  39. #include "graph/common/local_context.h"
  40. using std::set;
  41. using std::string;
  42. using std::vector;
  43. using std::map;
  44. namespace ge {
  45. namespace multibatch {
  46. namespace {
  47. const char *const kMbatchSwitchnName = "mbatch-switch-name";
  48. const int kSwitchNDataIndex = 0;
  49. const int kSwitchNPredIndex = 1;
  50. const int kDataOutIndex = 0;
  51. const int kDataInIndex = 0;
  52. const int kMergeDataOutIndex = 0;
  53. const int kStaticOutput = -1;
  54. inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); }
  55. NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const ComputeGraphPtr &graph) {
  56. OpDescPtr desc = MakeShared<OpDesc>();
  57. if (desc == nullptr) {
  58. GELOGE(OUT_OF_MEMORY, "Failed to insert merge node, name %s", name.c_str());
  59. return nullptr;
  60. }
  61. desc->SetName(name);
  62. desc->SetType(MERGE);
  63. GeTensorDesc tensor_desc;
  64. for (size_t i = 0; i < input_num; ++i) {
  65. auto ret = desc->AddInputDesc("x" + std::to_string(i), tensor_desc);
  66. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  67. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add input %zu, error-code %u",
  68. name.c_str(), i, ret);
  69. return nullptr);
  70. }
  71. auto ret = desc->AddOutputDesc("y", tensor_desc);
  72. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  73. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add output 'y', error-code %u",
  74. name.c_str(), ret);
  75. return nullptr);
  76. tensor_desc.SetDataType(DT_INT32);
  77. ret = desc->AddOutputDesc("value_index", tensor_desc);
  78. if (ret != GRAPH_SUCCESS) {
  79. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add output 'value_index', error-code %u",
  80. name.c_str(), ret);
  81. return nullptr;
  82. }
  83. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  84. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add attr", name.c_str());
  85. return nullptr;
  86. }
  87. return graph->AddNode(desc);
  88. }
  89. NodePtr InsertCopyNode(const NodePtr &node, size_t n) {
  90. const std::string &name = node->GetName() + "_ascend_mbatch_batch_" + std::to_string(n);
  91. auto src_op_desc = node->GetOpDesc();
  92. GE_IF_BOOL_EXEC(src_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "Failed to copy node %s to %s, the OpDesc is null",
  93. node->GetName().c_str(), name.c_str());
  94. return nullptr);
  95. auto desc = AttrUtils::CopyOpDesc(src_op_desc);
  96. GE_IF_BOOL_EXEC(desc == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to create op desc for copy node for node %s name %s",
  97. node->GetName().c_str(), name.c_str());
  98. return nullptr);
  99. desc->SetName(name);
  100. desc->CopyAttrsFrom(*src_op_desc);
  101. for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
  102. auto input_desc = desc->MutableInputDesc(i);
  103. GE_IF_BOOL_EXEC(input_desc == nullptr,
  104. GELOGW("Get null input desc by index %u from node %s when copy from %s", i,
  105. desc->GetName().c_str(), node->GetName().c_str());
  106. continue);
  107. input_desc->CopyAttrsFrom(src_op_desc->GetInputDesc(i));
  108. }
  109. for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
  110. auto output_desc = desc->MutableOutputDesc(i);
  111. GE_IF_BOOL_EXEC(output_desc == nullptr,
  112. GELOGE(INTERNAL_ERROR, "Failed to get output desc by index %u from node %s when copy from %s", i,
  113. desc->GetName().c_str(), node->GetName().c_str());
  114. return nullptr);
  115. output_desc->CopyAttrsFrom(src_op_desc->GetOutputDesc(i));
  116. }
  117. const std::string &batch_label = "Batch_" + std::to_string(n);
  118. if (!AttrUtils::SetStr(desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  119. GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", name.c_str());
  120. return nullptr;
  121. }
  122. (void)AttrUtils::SetListStr(desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, {node->GetName()});
  123. auto graph = node->GetOwnerComputeGraph();
  124. return graph->AddNode(desc);
  125. }
  126. bool IsAllDimsPositive(const std::vector<int64_t> &dims) {
  127. for (auto dim : dims) {
  128. if (dim < 0) {
  129. return false;
  130. }
  131. }
  132. return true;
  133. }
  134. NodePtr InsertConst(const std::string &name, const ComputeGraphPtr &graph) {
  135. auto desc = MakeShared<OpDesc>();
  136. if (desc == nullptr) {
  137. GELOGE(OUT_OF_MEMORY, "Failed to create const op %s, out of memory", name.c_str());
  138. return nullptr;
  139. }
  140. desc->SetName(name);
  141. desc->SetType(CONSTANT);
  142. GeTensor tensor;
  143. tensor.SetData(std::vector<uint8_t>({0}));
  144. if (!AttrUtils::SetTensor(desc, ATTR_NAME_WEIGHTS, tensor)) {
  145. GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", name.c_str());
  146. return nullptr;
  147. }
  148. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  149. GELOGE(OUT_OF_MEMORY, "Failed to set insert flag for const node %s", name.c_str());
  150. return nullptr;
  151. }
  152. if (desc->AddOutputDesc(GeTensorDesc()) != GRAPH_SUCCESS) {
  153. GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", name.c_str());
  154. return nullptr;
  155. }
  156. return graph->AddNode(desc);
  157. }
  158. bool IsOnlyOutputToAipp(const NodePtr &node) {
  159. for (const auto &out_node : node->GetOutDataNodes()) {
  160. if (out_node->GetType() != AIPP) {
  161. return false;
  162. }
  163. }
  164. return true;
  165. }
  166. Status CheckDataShape(const std::vector<NodePtr> &nodes) {
  167. size_t unknown_shape_count = 0;
  168. for (const auto &node : nodes) {
  169. if (node->GetType() != DATA) {
  170. continue;
  171. }
  172. for (auto dim : NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims()) {
  173. if (dim < 0) {
  174. unknown_shape_count++;
  175. break;
  176. }
  177. }
  178. }
  179. if (unknown_shape_count == 0) {
  180. ErrorManager::GetInstance().ATCReportErrMessage("E10040");
  181. GELOGE(PARAM_INVALID,
  182. "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
  183. return PARAM_INVALID;
  184. }
  185. return SUCCESS;
  186. }
  187. } // namespace
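// CopyGraph drives the whole multi-batch copy: collect the original nodes (Init), label each
// node's relation to the batch branches (LabelStatus), validate the dynamic inputs
// (CheckAndParseDynamicData), create the per-batch copies plus SwitchN/Merge nodes
// (CreateNewNodes), reconnect the edges (LinkEdges), insert Identity nodes after each SwitchN,
// prune useless nodes and finally check that no start node keeps an unknown shape.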
  188. Status MultiBatchGraphCopyer::CopyGraph() {
  189. auto ret = Init();
  190. if (ret != SUCCESS) {
  191. return ret;
  192. }
  193. if (LabelStatus() != SUCCESS) {
  194. GELOGE(INTERNAL_ERROR, "Failed to label status for all nodes.");
  195. return INTERNAL_ERROR;
  196. }
  197. ret = CheckAndParseDynamicData();
  198. if (ret != SUCCESS) {
  199. return ret;
  200. }
  201. ret = CreateNewNodes();
  202. if (ret != SUCCESS) {
  203. return ret;
  204. }
  205. ret = LinkEdges();
  206. if (ret != SUCCESS) {
  207. return ret;
  208. }
  209. ret = InsertIdentityAfterSwitchN();
  210. if (ret != SUCCESS) {
  211. GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node.");
  212. return INTERNAL_ERROR;
  213. }
  214. GELOGI("Begin to remove useless nodes by prune pass after copy process");
  215. PrunePass prune_pass;
  216. ret = prune_pass.Run(graph_);
  217. if (ret != SUCCESS) {
  218. GELOGE(ret, "Failed to prune");
  219. return ret;
  220. }
  221. return CheckCopyResult(origin_data_nodes_);
  222. }
  223. Status MultiBatchGraphCopyer::Init() {
  224. auto ret = CheckArguments();
  225. if (ret != SUCCESS) {
  226. return ret;
  227. }
  228. for (auto &node : graph_->GetAllNodes()) {
  229. origin_all_nodes_.emplace_back(node);
  230. if (IsDataLikeType(node->GetType())) {
  231. origin_data_nodes_.emplace_back(node);
  232. }
  233. }
  234. return SUCCESS;
  235. }
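// LabelStatus classifies every original node: data-like nodes become kNodeStartNode (unless they
// only feed AIPP), any node fed by a node inside a batch branch becomes kNodeInBatchBranch,
// nodes that own subgraphs are marked kNodeNotSupportNode, and the remaining nodes (including
// NETOUTPUT) stay kNodeOutBatchBranch.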
  236. Status MultiBatchGraphCopyer::LabelStatus() {
  237. map<string, vector<NodePtr>> frame_enters;
  238. InitStatus(frame_enters);
  239. bool changed = true;
  240. // If any of a node's input nodes is kNodeInBatchBranch, the node itself is also kNodeInBatchBranch
  241. while (changed) {
  242. changed = false;
  243. for (const auto &node : origin_all_nodes_) {
  244. for (auto &in_node : node->GetInAllNodes()) {
  245. bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() &&
  246. origin_nodes_status_[in_node.get()] == kNodeInBatchBranch;
  247. if (is_in_batch) {
  248. if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end() ||
  249. origin_nodes_status_[node.get()] != kNodeInBatchBranch) {
  250. origin_nodes_status_[node.get()] = kNodeInBatchBranch;
  251. ResetEnterStatus(frame_enters, node);
  252. changed = true;
  253. }
  254. break;
  255. }
  256. }
  257. }
  258. }
  259. for (const auto &node : origin_all_nodes_) {
  260. if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) {
  261. origin_nodes_status_[node.get()] = kNodeNotSupportNode;
  262. continue;
  263. }
  264. if (node->GetType() == NETOUTPUT) {
  265. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  266. continue;
  267. }
  268. if (IsDataLikeType(node->GetType())) {
  269. if (IsOnlyOutputToAipp(node)) {
  270. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  271. } else {
  272. origin_nodes_status_[node.get()] = kNodeStartNode;
  273. }
  274. continue;
  275. }
  276. if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) {
  277. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  278. }
  279. }
  280. return SUCCESS;
  281. }
  282. void MultiBatchGraphCopyer::InitStatus(map<string, vector<NodePtr>> &frame_enters) {
  283. for (const auto &node : origin_all_nodes_) {
  284. if (node->GetType() != ENTER && node->GetType() != REFENTER) {
  285. continue;
  286. }
  287. auto op_desc = node->GetOpDesc();
  288. if (op_desc == nullptr) {
  289. continue;
  290. }
  291. string frame_name;
  292. if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
  293. frame_enters[frame_name].emplace_back(node);
  294. }
  295. }
  296. for (const auto &data : origin_data_nodes_) {
  297. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  298. if (!IsAllDimsPositive(data_shape.GetDims())) {
  299. origin_nodes_status_[data.get()] = kNodeInBatchBranch;
  300. }
  301. }
  302. }
  303. void MultiBatchGraphCopyer::ResetEnterStatus(map<string, vector<NodePtr>> &frame_enters, const NodePtr &node) {
  304. if (node->GetType() != ENTER && node->GetType() != REFENTER) {
  305. return;
  306. }
  307. for (const auto &frame_enter : frame_enters) {
  308. auto &enters = frame_enter.second;
  309. if (std::find(enters.begin(), enters.end(), node) != enters.end()) {
  310. for (const auto &enter : enters) {
  311. origin_nodes_status_[enter.get()] = kNodeInBatchBranch;
  312. }
  313. break;
  314. }
  315. }
  316. }
  317. Status MultiBatchGraphCopyer::CheckAndParseDynamicData() {
  318. size_t unknown_shape_count = 0;
  319. auto data_name_and_shape = GetLocalOmgContext().user_input_dims;
  320. GELOGD("raw data_name_and_shape size: %zu", data_name_and_shape.size());
  321. for (const auto &node : origin_all_nodes_) {
  322. auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex);
  323. auto data_shape = data_desc.GetShape();
  324. auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" :
  325. data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others";
  326. auto data_name = node->GetName();
  327. auto branch_status = GetNodeStatus(node);
  328. if (branch_status != kNodeStartNode) {
  329. continue;
  330. }
  331. if (IsAllDimsPositive(data_shape.GetDims())) {
  332. continue;
  333. }
  334. ++unknown_shape_count;
  335. auto iter = find(data_name_order_.begin(), data_name_order_.end(), data_name);
  336. if (iter == data_name_order_.end()) {
  337. if (dynamic_type_ == DynamicType::kDynamicBatch) {
  338. auto ret = CheckDynamicBatchShape(data_shape.GetDims(), data_name);
  339. if (!ret) {
  340. return PARAM_INVALID;
  341. }
  342. } else if (dynamic_type_ == DynamicType::kDynamicImageSize) {
  343. auto ret = CheckDynamicImageSizeShape(data_shape.GetDims(), data_name, data_format);
  344. if (!ret) {
  345. return PARAM_INVALID;
  346. }
  347. } else if (dynamic_type_ == DynamicType::kDynamicDims) {
  348. ErrorManager::GetInstance().ATCReportErrMessage("E10001",
  349. {"parameter", "reason"},
  350. {"--input_shape",
  351. "all dynamic data must be set in --input_shape"});
  352. GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set in --input_shape",
  353. node->GetName().c_str(), data_shape.ToString().c_str());
  354. return INTERNAL_ERROR;
  355. }
  356. data_name_and_shape.emplace_back(data_name, data_shape.GetDims());
  357. }
  358. }
  359. auto ret = ParserDataToDynmaicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_);
  360. if (ret != SUCCESS) {
  361. return ret;
  362. }
  363. if (unknown_shape_count == 0) {
  364. ErrorManager::GetInstance().ATCReportErrMessage("E10040");
  365. GELOGE(PARAM_INVALID,
  366. "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
  367. return PARAM_INVALID;
  368. }
  369. return SUCCESS;
  370. }
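// CreateNewNodes first inserts the shared shape data node, then handles each node according to
// its label: start nodes get a SwitchN and have their shape enlarged to the biggest gear,
// nodes inside a batch branch are copied once per configured shape, and nodes outside the
// branch get Merge nodes prepared for their in-branch inputs.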
  371. Status MultiBatchGraphCopyer::CreateNewNodes() {
  372. shape_data_ = InsertShapeDataNode();
  373. if (shape_data_ == nullptr) {
  374. GELOGE(INTERNAL_ERROR, "Failed to create the shape data node for multi-batch");
  375. return INTERNAL_ERROR;
  376. }
  377. for (const auto &node : origin_all_nodes_) {
  378. auto node_type = node->GetType();
  379. Status ret = INTERNAL_ERROR;
  380. auto branch_status = GetNodeStatus(node);
  381. GELOGD("Process node %s, status %d", node->GetName().c_str(), static_cast<int>(branch_status));
  382. switch (branch_status) {
  383. case kNodeStartNode:
  384. GELOGD("Name: %s, type: %s, status: kNodeStartNode.", node->GetName().c_str(), node->GetType().c_str());
  385. ret = InsertSwitchNForData(node);
  386. if (ret == SUCCESS) {
  387. ret = UpdateMaxShapeToData(node);
  388. }
  389. break;
  390. case kNodeInBatchBranch:
  391. GELOGD("Name: %s, type: %s, status: kNodeInBatchBranch.", node->GetName().c_str(), node->GetType().c_str());
  392. ret = CopyNodeInBatchBranch(node);
  393. break;
  394. case kNodeOutBatchBranch:
  395. GELOGD("Name: %s, type: %s, status: kNodeOutBatchBranch.", node->GetName().c_str(), node->GetType().c_str());
  396. ret = InsertMergeForEdgeNode(node);
  397. break;
  398. case kNodeNotSupportNode:
  399. GELOGD("Name: %s, type: %s, status: kNodeNotSupportNode.", node->GetName().c_str(), node->GetType().c_str());
  400. break;
  401. default:
  402. GELOGE(INTERNAL_ERROR, "Unexpected status %d on node %s", static_cast<int>(branch_status),
  403. node->GetName().c_str());
  404. break;
  405. }
  406. if (ret != SUCCESS) {
  407. GELOGE(ret, "Failed to deal with node %s in multi-batch process", node->GetName().c_str());
  408. return ret;
  409. }
  410. }
  411. return SUCCESS;
  412. }
  413. NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) {
  414. if (index < 0) {
  415. // the merge node must have data inputs; if the origin connection is a control
  416. // edge, we use a data edge instead
  417. index = 0;
  418. }
  419. auto &merge_nodes = nodes_to_merge_nodes_[node.get()];
  420. if (merge_nodes.empty()) {
  421. auto count = node->GetAllOutDataAnchorsSize();
  422. if (count == 0) {
  423. count = 1;
  424. }
  425. merge_nodes.resize(count, nullptr);
  426. }
  427. if (merge_nodes.at(index) != nullptr) {
  428. return merge_nodes[index];
  429. }
  430. auto merge_node_name = node->GetName() + "_ascend_mbatch_merge_" + std::to_string(index);
  431. auto merge_node = InsertMergeNodeToGraph(merge_node_name, shapes_.size(), node->GetOwnerComputeGraph());
  432. GE_IF_BOOL_EXEC(merge_node == nullptr, GELOGE(INTERNAL_ERROR, "Failed to create merge node for node %s, out index %d",
  433. node->GetName().c_str(), index);
  434. return nullptr);
  435. merge_nodes[index] = merge_node;
  436. GELOGI("Create merge node %s for node %s index %d", merge_node_name.c_str(), node->GetName().c_str(), index);
  437. return merge_node;
  438. }
  439. Status MultiBatchGraphCopyer::CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr &copyed_node) {
  440. for (auto &in_anchor : origin_node->GetAllInDataAnchors()) {
  441. auto origin_src_anchor = in_anchor->GetPeerOutAnchor();
  442. if (origin_src_anchor == nullptr) {
  443. GELOGD("The node %s does not have input on index %d", origin_node->GetName().c_str(), in_anchor->GetIdx());
  444. continue;
  445. }
  446. auto origin_src_node = origin_src_anchor->GetOwnerNode();
  447. auto dst_anchor = copyed_node->GetInDataAnchor(in_anchor->GetIdx());
  448. GE_CHECK_NOTNULL(dst_anchor);
  449. auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get());
  450. if (switchn_iter != data_nodes_to_switchn_.end()) {
  451. auto ret = GraphUtils::AddEdge(switchn_iter->second->GetOutDataAnchor(batch_num), dst_anchor);
  452. if (ret != GRAPH_SUCCESS) {
  453. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u",
  454. switchn_iter->second->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(),
  455. ret);
  456. return INTERNAL_ERROR;
  457. }
  458. GELOGD("Add data edge from %s(%d) to %s(%d)", switchn_iter->second->GetName().c_str(), batch_num,
  459. copyed_node->GetName().c_str(), in_anchor->GetIdx());
  460. continue;
  461. }
  462. auto batch_branch_iter = nodes_to_batch_nodes_.find(origin_src_node.get());
  463. if (batch_branch_iter != nodes_to_batch_nodes_.end()) {
  464. auto src_batch_node = batch_branch_iter->second.at(batch_num);
  465. auto ret = GraphUtils::AddEdge(src_batch_node->GetOutDataAnchor(origin_src_anchor->GetIdx()), dst_anchor);
  466. if (ret != GRAPH_SUCCESS) {
  467. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u",
  468. src_batch_node->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(), ret);
  469. return INTERNAL_ERROR;
  470. }
  471. GELOGD("Add data edge from %s(%d) to %s(%d)", src_batch_node->GetName().c_str(), batch_num,
  472. copyed_node->GetName().c_str(), in_anchor->GetIdx());
  473. continue;
  474. }
  475. auto ret = GraphUtils::AddEdge(origin_src_anchor, dst_anchor);
  476. if (ret != GRAPH_SUCCESS) {
  477. GELOGE(INTERNAL_ERROR, "Failed to add data edge between origin node %s(%d) to copyed %s(%d)",
  478. origin_src_node->GetName().c_str(), origin_src_anchor->GetIdx(), copyed_node->GetName().c_str(),
  479. dst_anchor->GetIdx());
  480. return INTERNAL_ERROR;
  481. }
  482. GELOGD("Add data edge between branch-out %s(%d) to branch-in %s(%d)", origin_src_node->GetName().c_str(),
  483. origin_src_anchor->GetIdx(), copyed_node->GetName().c_str(), dst_anchor->GetIdx());
  484. }
  485. return SUCCESS;
  486. }
  487. Status MultiBatchGraphCopyer::CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr &copyed_node) {
  488. for (auto &origin_src_node : node->GetInControlNodes()) {
  489. auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get());
  490. if (switchn_iter != data_nodes_to_switchn_.end()) {
  491. // reconnect data node
  492. auto ret = GraphUtils::AddEdge(switchn_iter->second->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  493. if (ret != GRAPH_SUCCESS) {
  494. GELOGE(INTERNAL_ERROR, "Failed to add control edge between %s to %s, error-code %u",
  495. switchn_iter->second->GetName().c_str(), copyed_node->GetName().c_str(), ret);
  496. return INTERNAL_ERROR;
  497. }
  498. GELOGD("Add control edge from %s to %s", switchn_iter->second->GetName().c_str(), copyed_node->GetName().c_str());
  499. continue;
  500. }
  501. auto batch_branch_iter = nodes_to_batch_nodes_.find(origin_src_node.get());
  502. if (batch_branch_iter != nodes_to_batch_nodes_.end()) {
  503. // reconnect node in batch branch
  504. auto src_batch_node = batch_branch_iter->second.at(batch_num);
  505. auto ret = GraphUtils::AddEdge(src_batch_node->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  506. if (ret != GRAPH_SUCCESS) {
  507. GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s, error-code %u",
  508. src_batch_node->GetName().c_str(), copyed_node->GetName().c_str(), ret);
  509. return INTERNAL_ERROR;
  510. }
  511. GELOGD("Add control edge from %s to %s", src_batch_node->GetName().c_str(), copyed_node->GetName().c_str());
  512. continue;
  513. }
  514. auto ret = GraphUtils::AddEdge(origin_src_node->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  515. if (ret != GRAPH_SUCCESS) {
  516. GELOGE(INTERNAL_ERROR, "Failed to add control edge from origin %s to copyed %s",
  517. origin_src_node->GetName().c_str(), copyed_node->GetName().c_str());
  518. return INTERNAL_ERROR;
  519. }
  520. GELOGD("Add control edge between branch-out %s to branch-in %s", origin_src_node->GetName().c_str(),
  521. copyed_node->GetName().c_str());
  522. }
  523. return SUCCESS;
  524. }
  525. NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() {
  526. auto desc = MakeShared<OpDesc>();
  527. if (desc == nullptr) {
  528. GELOGE(OUT_OF_MEMORY, "Failed to create shape data node, out of memory");
  529. return nullptr;
  530. }
  531. string node_name = "ascend_mbatch_shape_data";
  532. // Only flush subgraph name
  533. if (graph_->GetParentGraph() != nullptr) {
  534. node_name = graph_->GetName() + "_" + node_name;
  535. }
  536. desc->SetName(node_name);
  537. desc->SetType(DATA);
  538. GeTensorDesc tensor_desc;
  539. tensor_desc.SetFormat(FORMAT_ND);
  540. tensor_desc.SetShape(GeShape({static_cast<int64_t>(shapes_.at(0).size())}));
  541. tensor_desc.SetDataType(DT_INT64);
  542. auto ret = desc->AddInputDesc(tensor_desc);
  543. if (ret != GRAPH_SUCCESS) {
  544. GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
  545. return nullptr;
  546. }
  547. ret = desc->AddOutputDesc(tensor_desc);
  548. if (ret != GRAPH_SUCCESS) {
  549. GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data");
  550. return nullptr;
  551. }
  552. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  553. GELOGE(INTERNAL_ERROR, "Failed to add attr for created data");
  554. return nullptr;
  555. }
  556. auto data_node = graph_->AddNode(desc);
  557. if (data_node == nullptr) {
  558. GELOGE(INTERNAL_ERROR, "Failed to add shape data node to graph");
  559. return nullptr;
  560. }
  561. ret = GraphUtils::AppendInputNode(graph_, data_node);
  562. if (ret != GRAPH_SUCCESS) {
  563. GELOGE(INTERNAL_ERROR, "Failed to append data node %s as input to graph", data_node->GetName().c_str());
  564. return nullptr;
  565. }
  566. return data_node;
  567. }
  568. Status MultiBatchGraphCopyer::CheckArguments() {
  569. if (graph_ == nullptr) {
  570. GELOGE(PARAM_INVALID, "Failed to copy graph, the graph is null");
  571. return PARAM_INVALID;
  572. }
  573. return CheckDynamicParams(shapes_);
  574. }
  575. Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_nodes) {
  576. for (auto &node : start_nodes) {
  577. if (IsOnlyOutputToAipp(node)) {
  578. continue;
  579. }
  580. auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
  581. if (!IsAllDimsPositive(dims)) {
  582. GELOGE(INTERNAL_ERROR, "Failed to copy multi batch graph, the node %s still has unknown shape %s",
  583. node->GetName().c_str(), formats::ShapeToString(dims).c_str());
  584. return INTERNAL_ERROR;
  585. }
  586. }
  587. return SUCCESS;
  588. }
  589. bool MultiBatchGraphCopyer::IsInBatchBranch(const NodePtr &node) {
  590. return (nodes_to_batch_nodes_.count(node.get()) > 0) || (data_nodes_to_switchn_.count(node.get()) > 0);
  591. }
  592. Status MultiBatchGraphCopyer::LinkDataToMerge(const NodePtr &data, const NodePtr &merge) {
  593. // The caller should make sure that there is a SwitchN node in the map
  594. auto &switchn = data_nodes_to_switchn_[data.get()];
  595. GELOGI("Link edge between data %s to merge %s throw switchn %s", data->GetName().c_str(), merge->GetName().c_str(),
  596. switchn->GetName().c_str());
  597. for (size_t i = 0; i < shapes_.size(); ++i) {
  598. auto ret = GraphUtils::AddEdge(switchn->GetOutDataAnchor(i), merge->GetInDataAnchor(i));
  599. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  600. GELOGE(INTERNAL_ERROR, "Failed to add edge between switchn %s(%zu) to merge %s(%zu), error-code %u",
  601. switchn->GetName().c_str(), i, merge->GetName().c_str(), i, ret);
  602. return INTERNAL_ERROR);
  603. }
  604. return SUCCESS;
  605. }
  606. Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge) {
  607. auto &copyed_nodes = nodes_to_batch_nodes_[node.get()];
  608. if (copyed_nodes.size() != shapes_.size()) {
  609. GELOGE(INTERNAL_ERROR,
  610. "Failed to create merge node for node %s, the copyed nodes for it count %zu different with shape %zu",
  611. node->GetName().c_str(), copyed_nodes.size(), shapes_.size());
  612. return INTERNAL_ERROR;
  613. }
  614. for (size_t i = 0; i < copyed_nodes.size(); ++i) {
  615. auto src_node = copyed_nodes[i];
  616. if (src_node->GetAllOutDataAnchorsSize() == 0) {
  617. // if the node does not have any data output, we should create a const for it, like this:
  618. //        c (control edge)    d (data edge)
  619. // node ----------> const ----------> merge
  620. auto const_name = src_node->GetName() + "_merge_const";
  621. GELOGI("The node %s on the batch branch edge does not have any data output, create a const %s for it",
  622. src_node->GetName().c_str(), const_name.c_str());
  623. auto const_node = InsertConst(const_name, graph_);
  624. GE_IF_BOOL_EXEC(const_node == nullptr,
  625. GELOGE(OUT_OF_MEMORY, "Failed to create const for node %s to connect to a merge node",
  626. src_node->GetName().c_str());
  627. return OUT_OF_MEMORY);
  628. auto ret = GraphUtils::AddEdge(src_node->GetOutControlAnchor(), const_node->GetInControlAnchor());
  629. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s",
  630. src_node->GetName().c_str(), const_node->GetName().c_str());
  631. return INTERNAL_ERROR);
  632. src_node = const_node;
  633. }
  634. auto ret = GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_index), merge->GetInDataAnchor(i));
  635. if (ret != GRAPH_SUCCESS) {
  636. GELOGE(INTERNAL_ERROR,
  637. "Failed to add edge between copyed node %s(%d) to inserted merge node %s(%zu), error-code %u",
  638. copyed_nodes[i]->GetName().c_str(), out_index, merge->GetName().c_str(), i, ret);
  639. return INTERNAL_ERROR;
  640. }
  641. }
  642. return SUCCESS;
  643. }
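// UpdateMaxShapeToData rewrites a dynamic Data node's input/output shape to the configured
// batch shape with the largest element count; fully static Data nodes are left untouched.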
  644. Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) {
  645. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  646. auto data_name = data->GetName();
  647. if (IsAllDimsPositive(data_shape.GetDims())) {
  648. return SUCCESS;
  649. }
  650. size_t max_shape_index = 0;
  651. int64_t max_size = 0;
  652. for (size_t i = 0; i < shapes_.size(); ++i) {
  653. int64_t size = 1;
  654. for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
  655. if (INT64_MAX / dim < size) {
  656. GELOGE(PARAM_INVALID, "The shape %s size overflow",
  657. formats::ShapeToString(data_to_dynamic_info_[data_name].at(i)).c_str());
  658. return PARAM_INVALID;
  659. }
  660. size *= dim;
  661. }
  662. if (size > max_size) {
  663. max_size = size;
  664. max_shape_index = i;
  665. }
  666. }
  667. // must not fail, the calculated result has already been checked in function InsertSwitchNForData
  668. (void)CalcShape(data_to_dynamic_info_.at(data_name).at(max_shape_index), data_shape);
  669. auto ret = NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape);
  670. if (ret != GRAPH_SUCCESS) {
  671. GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str());
  672. return INTERNAL_ERROR;
  673. }
  674. ret = NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape);
  675. if (ret != GRAPH_SUCCESS) {
  676. GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str());
  677. return INTERNAL_ERROR;
  678. }
  679. GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(),
  680. formats::ShapeToString(data_shape).c_str());
  681. return SUCCESS;
  682. }
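// InsertSwitchNForData inserts a SwitchN behind a dynamic Data node: input "data" takes the
// original tensor, input "pred_value" is later wired to the shape data node, and one output is
// created per configured shape, carrying the calculated dims as attributes.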
  683. Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) {
  684. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  685. auto data_name = data->GetName();
  686. (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  687. if (IsAllDimsPositive(data_shape.GetDims())) {
  688. GELOGI("The shape of data %s are positive(%s), skip the multi batch process", data->GetName().c_str(),
  689. data_shape.ToString().c_str());
  690. return SUCCESS;
  691. }
  692. auto switchn_desc = MakeShared<OpDesc>();
  693. if (switchn_desc == nullptr) {
  694. GELOGE(OUT_OF_MEMORY, "Failed to create switchn for data %s", data->GetName().c_str());
  695. return OUT_OF_MEMORY;
  696. }
  697. switchn_desc->SetName(data->GetName() + "_ascend_mbatch_switchn");
  698. switchn_desc->SetType(SWITCHN);
  699. GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex));
  700. if (switchn_desc->AddInputDesc("data", tensor) != GRAPH_SUCCESS) { // data
  701. return OUT_OF_MEMORY;
  702. }
  703. GeTensorDesc pred_tensor;
  704. if (switchn_desc->AddInputDesc("pred_value", pred_tensor) != GRAPH_SUCCESS) { // pred
  705. return OUT_OF_MEMORY;
  706. }
  707. std::vector<std::string> input_dims_str;
  708. for (size_t i = 0; i < shapes_.size(); ++i) {
  709. auto shape = data_shape;
  710. auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
  711. if (ret != SUCCESS) {
  712. GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match",
  713. data->GetName().c_str());
  714. return ret;
  715. }
  716. tensor.SetShape(shape);
  717. string input_str;
  718. int64_t tensor_size = 0;
  719. (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
  720. input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
  721. TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" +
  722. std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
  723. formats::JoinToString(tensor.GetShape().GetDims());
  724. input_dims_str.emplace_back(input_str);
  725. if (!AttrUtils::SetListInt(tensor, ATTR_NAME_SWITCHN_PRED_VALUE, shapes_.at(i))) {
  726. GELOGE(INTERNAL_ERROR, "Failed to add attr value on output %zu tensor", i);
  727. return INTERNAL_ERROR;
  728. }
  729. (void) AttrUtils::SetListInt(tensor, ATTR_NAME_COMBINED_DYNAMIC_DIMS, shape.GetDims());
  730. if (switchn_desc->AddOutputDesc("output" + std::to_string(i), tensor) != GRAPH_SUCCESS) {
  731. GELOGE(GRAPH_FAILED, "Opdesc AddOutputDesc failed");
  732. return GRAPH_FAILED;
  733. }
  734. GELOGD("The SwitchN %s output index %zu, shape %s", switchn_desc->GetName().c_str(), i, shape.ToString().c_str());
  735. }
  736. (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);
  737. if (!AttrUtils::SetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
  738. GELOGE(INTERNAL_ERROR, "Failed to add user designate shape order attr on switchn node %s",
  739. switchn_desc->GetName().c_str());
  740. return INTERNAL_ERROR;
  741. }
  742. if (!AttrUtils::SetBool(switchn_desc, ATTR_INSERT_BY_MBATCH, true)) {
  743. GELOGE(INTERNAL_ERROR, "Failed to add insert attr on switchn node %s", switchn_desc->GetName().c_str());
  744. return INTERNAL_ERROR;
  745. }
  746. if (!AttrUtils::SetStr(data->GetOpDesc(), kMbatchSwitchnName, switchn_desc->GetName())) {
  747. GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", data->GetName().c_str());
  748. return INTERNAL_ERROR;
  749. }
  750. if (StampDynamicType(switchn_desc) != SUCCESS) {
  751. GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr on switchn node %s", switchn_desc->GetName().c_str());
  752. return INTERNAL_ERROR;
  753. }
  754. auto switchn = graph_->AddNode(switchn_desc);
  755. if (switchn == nullptr) {
  756. GELOGE(OUT_OF_MEMORY, "Failed to create switchn %s from desc", switchn_desc->GetName().c_str());
  757. return OUT_OF_MEMORY;
  758. }
  759. data_nodes_to_switchn_[data.get()] = switchn;
  760. return SUCCESS;
  761. }
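// InsertMergeForEdgeNode prepares a Merge node for every data or control input of an
// out-of-branch node that comes from inside a batch branch; the edges themselves are rewired
// later in LinkEdges/LinkToNodeOutBranch.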
  762. Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) {
  763. for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
  764. auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
  765. if (src_out_anchor == nullptr) {
  766. GELOGD("The node %s does not has input at index %d", node->GetName().c_str(), in_data_anchor->GetIdx());
  767. continue;
  768. }
  769. auto in_node = src_out_anchor->GetOwnerNode();
  770. if (!IsInBatchBranch(in_node)) {
  771. continue;
  772. }
  773. auto merge_node = InsertMergeNode(in_node, src_out_anchor->GetIdx());
  774. if (merge_node == nullptr) {
  775. return INTERNAL_ERROR;
  776. }
  777. }
  778. for (auto &in_node : node->GetInControlNodes()) {
  779. if (!IsInBatchBranch(in_node)) {
  780. continue;
  781. }
  782. auto merge_node = InsertMergeNode(in_node, -1);
  783. if (merge_node == nullptr) {
  784. return INTERNAL_ERROR;
  785. }
  786. }
  787. return SUCCESS;
  788. }
  789. Status MultiBatchGraphCopyer::CopyNodeInBatchBranch(const NodePtr &node) {
  790. auto &copyed_nodes = nodes_to_batch_nodes_[node.get()];
  791. for (size_t i = 0; i < shapes_.size(); ++i) {
  792. auto copyed_node = InsertCopyNode(node, i);
  793. if (copyed_node == nullptr) {
  794. GELOGE(INTERNAL_ERROR, "Failed to add node to graph when copy node %s", node->GetName().c_str());
  795. return INTERNAL_ERROR;
  796. }
  797. copyed_nodes.emplace_back(copyed_node);
  798. GELOGI("Copy node %s type %s for shape %s, new node name %s", node->GetName().c_str(), node->GetType().c_str(),
  799. formats::JoinToString(shapes_.at(i)).c_str(), copyed_node->GetName().c_str());
  800. }
  801. return SUCCESS;
  802. }
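// LinkEdges walks the original nodes and reconnects everything: Data nodes to their SwitchN,
// Merge nodes to their sources, copied nodes inside each batch branch, and out-of-branch nodes
// to the Merge outputs that replace their original in-branch inputs.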
  803. Status MultiBatchGraphCopyer::LinkEdges() {
  804. Status ret;
  805. for (const auto &node : origin_all_nodes_) {
  806. if (data_nodes_to_switchn_.count(node.get()) > 0) {
  807. ret = LinkDataToSwitchN(node);
  808. if (ret != SUCCESS) {
  809. return ret;
  810. }
  811. }
  812. if (nodes_to_merge_nodes_.count(node.get()) > 0) {
  813. ret = LinkToMerge(node);
  814. if (ret != SUCCESS) {
  815. return ret;
  816. }
  817. }
  818. if (nodes_to_batch_nodes_.count(node.get()) > 0) {
  819. ret = LinkToNodeInBranch(node);
  820. } else {
  821. ret = LinkToNodeOutBranch(node);
  822. }
  823. if (ret != SUCCESS) {
  824. return ret;
  825. }
  826. }
  827. return SUCCESS;
  828. }
  829. Status MultiBatchGraphCopyer::LinkDataToSwitchN(const NodePtr &data) {
  830. auto switchn = data_nodes_to_switchn_[data.get()];
  831. auto ret =
  832. GraphUtils::AddEdge(shape_data_->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNPredIndex));
  833. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link shape data %s to switchn %s",
  834. shape_data_->GetName().c_str(), switchn->GetName().c_str());
  835. return INTERNAL_ERROR);
  836. ret = GraphUtils::AddEdge(data->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNDataIndex));
  837. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link data %s to switchn %s",
  838. data->GetName().c_str(), switchn->GetName().c_str());
  839. return INTERNAL_ERROR);
  840. return SUCCESS;
  841. }
  842. Status MultiBatchGraphCopyer::LinkToMerge(const NodePtr &node) {
  843. auto &merge_nodes = nodes_to_merge_nodes_[node.get()];
  844. for (size_t i = 0; i < merge_nodes.size(); ++i) {
  845. auto merge_node = merge_nodes[i];
  846. if (merge_node == nullptr) {
  847. continue;
  848. }
  849. if (nodes_to_batch_nodes_.count(node.get()) > 0) {
  850. auto ret = LinkNodeToMerge(node, i, merge_node);
  851. if (ret != SUCCESS) {
  852. return ret;
  853. }
  854. continue;
  855. }
  856. if (data_nodes_to_switchn_.count(node.get()) > 0) {
  857. auto ret = LinkDataToMerge(node, merge_node);
  858. if (ret != SUCCESS) {
  859. return ret;
  860. }
  861. continue;
  862. }
  863. GELOGE(INTERNAL_ERROR, "The merge node %s is created, index %zu, but can not find the src node",
  864. merge_node->GetName().c_str(), i);
  865. return INTERNAL_ERROR;
  866. }
  867. return SUCCESS;
  868. }
  869. Status MultiBatchGraphCopyer::LinkToNodeInBranch(const NodePtr &node) {
  870. auto &branch_nodes = nodes_to_batch_nodes_[node.get()];
  871. for (size_t i = 0; i < branch_nodes.size(); ++i) {
  872. auto ret = CopyInDataEdges(node, i, branch_nodes[i]);
  873. if (ret != SUCCESS) {
  874. return ret;
  875. }
  876. ret = CopyInControlEdges(node, i, branch_nodes[i]);
  877. if (ret != SUCCESS) {
  878. return ret;
  879. }
  880. }
  881. return SUCCESS;
  882. }
  883. Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) {
  884. for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
  885. auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
  886. if (src_out_anchor == nullptr) {
  887. GELOGD("The node %s does not has input at index %d", node->GetName().c_str(), in_data_anchor->GetIdx());
  888. continue;
  889. }
  890. auto in_node = src_out_anchor->GetOwnerNode();
  891. if (!IsInBatchBranch(in_node)) {
  892. continue;
  893. }
  894. auto iter = nodes_to_merge_nodes_.find(in_node.get());
  895. if (iter == nodes_to_merge_nodes_.end()) {
  896. GELOGE(INTERNAL_ERROR, "Failed to link IO data edge from %s(%d) to %s(%d), no merge node found",
  897. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  898. return INTERNAL_ERROR;
  899. }
  900. auto merge_node = iter->second[src_out_anchor->GetIdx()];
  901. if (merge_node == nullptr) {
  902. GELOGE(INTERNAL_ERROR, "Failed to link IO data edge from %s(%d) to %s(%d), no merge node found",
  903. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  904. return INTERNAL_ERROR;
  905. }
  906. auto ret = src_out_anchor->Unlink(in_data_anchor);
  907. if (ret != GRAPH_SUCCESS) {
  908. GELOGE(INTERNAL_ERROR, "Failed to unlink the control edge from %s(%d) to %s(%d)", in_node->GetName().c_str(),
  909. src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  910. return INTERNAL_ERROR;
  911. }
  912. ret = GraphUtils::AddEdge(merge_node->GetOutDataAnchor(kMergeDataOutIndex), in_data_anchor);
  913. if (ret != GRAPH_SUCCESS) {
  914. GELOGE(INTERNAL_ERROR, "Failed to add data edge from %s(%d) to %s(%d)", merge_node->GetName().c_str(),
  915. src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  916. return INTERNAL_ERROR;
  917. }
  918. GELOGI("Link data edge from merge %s(from %s(%d)) to %s(%d)", merge_node->GetName().c_str(),
  919. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  920. }
  921. for (auto &in_node : node->GetInControlNodes()) {
  922. if (!IsInBatchBranch(in_node)) {
  923. continue;
  924. }
  925. auto iter = nodes_to_merge_nodes_.find(in_node.get());
  926. if (iter == nodes_to_merge_nodes_.end()) {
  927. GELOGE(INTERNAL_ERROR, "Failed to link IO control edge from %s to %s, no merge node found",
  928. in_node->GetName().c_str(), node->GetName().c_str());
  929. return INTERNAL_ERROR;
  930. }
  931. auto merge_node = iter->second[0];
  932. if (merge_node == nullptr) {
  933. GELOGE(INTERNAL_ERROR, "Failed to link IO control edge from %s to %s, no merge node found",
  934. in_node->GetName().c_str(), node->GetName().c_str());
  935. return INTERNAL_ERROR;
  936. }
  937. GE_IF_BOOL_EXEC(in_node->GetOutControlAnchor() == nullptr,
  938. GELOGE(INTERNAL_ERROR, "Innode outputControlAnchor is null");
  939. return INTERNAL_ERROR);
  940. auto ret = in_node->GetOutControlAnchor()->Unlink(node->GetInControlAnchor());
  941. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to unlink the control edge from %s to %s",
  942. in_node->GetName().c_str(), node->GetName().c_str());
  943. return INTERNAL_ERROR);
  944. ret = GraphUtils::AddEdge(merge_node->GetOutControlAnchor(), node->GetInControlAnchor());
  945. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s",
  946. merge_node->GetName().c_str(), node->GetName().c_str());
  947. return INTERNAL_ERROR);
  948. GELOGI("Link control edge from merge %s(from %s) to %s", merge_node->GetName().c_str(), in_node->GetName().c_str(),
  949. node->GetName().c_str());
  950. }
  951. return SUCCESS;
  952. }
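// InsertIdentityAfterSwitchN attaches an Identity node to every SwitchN output whose consumer is
// not a Merge inserted by this pass: the SwitchN data output is linked to the Identity and the
// Identity's control output to the consumer, copying the consumer's batch label when present.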
  953. Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() {
  954. for (auto &node : graph_->GetAllNodes()) {
  955. if (node->GetType() != SWITCHN) {
  956. continue;
  957. }
  958. auto switchn_desc = node->GetOpDesc();
  959. GE_CHECK_NOTNULL(switchn_desc);
  960. size_t i = 0;
  961. for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  962. for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  963. auto out_node = in_data_anchor->GetOwnerNode();
  964. auto op_desc = out_node->GetOpDesc();
  965. GE_CHECK_NOTNULL(op_desc);
  966. if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
  967. GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str());
  968. continue;
  969. }
  970. auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i), IDENTITY);
  971. GE_CHECK_NOTNULL(identity_desc);
  972. string batch_label;
  973. if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  974. if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  975. GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str());
  976. return FAILED;
  977. }
  978. }
  979. auto data_desc = switchn_desc->GetOutputDesc(i);
  980. i++;
  981. GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc));
  982. GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc));
  983. auto identity_node = graph_->AddNode(identity_desc);
  984. GE_CHECK_NOTNULL(identity_node);
  985. GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0)));
  986. GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor());
  987. GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor()));
  988. }
  989. }
  990. }
  991. return SUCCESS;
  992. }
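// ProcessMultiBatch is the entry point of this pass: it reads the dynamic batch/image-size/dims
// options, feeds the parsed shapes and user-designated inputs into a MultiBatchGraphCopyer and
// runs CopyGraph; when no multi-batch option is configured the graph is left unchanged.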
  993. Status ProcessMultiBatch(ComputeGraphPtr &graph) {
  994. std::vector<std::vector<int64_t>> shapes;
  995. if (!InitDynamicParams(shapes)) {
  996. GELOGD("There is no multi-batch options, no need to process multi-batch copy");
  997. return SUCCESS;
  998. }
  999. DynamicType dynamic_type = DynamicType::kDynamicUnknown;
  1000. if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
  1001. dynamic_type = DynamicType::kDynamicBatch;
  1002. } else if (!GetLocalOmgContext().dynamic_image_size.empty()) {
  1003. dynamic_type = DynamicType::kDynamicImageSize;
  1004. } else if (!GetLocalOmgContext().dynamic_dims.empty()) {
  1005. dynamic_type = DynamicType::kDynamicDims;
  1006. }
  1007. std::vector<std::pair<std::string, std::vector<int64_t>>> user_designate_shape;
  1008. user_designate_shape = GetLocalOmgContext().user_input_dims;
  1009. GELOGI("Begin to copy graph for multi-batch");
  1010. multibatch::MultiBatchGraphCopyer copyer(graph);
  1011. for (auto &shape : shapes) {
  1012. copyer.AddShape(shape);
  1013. }
  1014. copyer.SetDynamicType(dynamic_type);
  1015. copyer.SetUserDesignateShape(user_designate_shape);
  1016. return copyer.CopyGraph();
  1017. }
  1018. // +-----------+
  1019. // | Data | +-----------+ +-----------+ +-----------+
  1020. // +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  1021. // \ /. +-----------+ +-----------+ +-----------+
  1022. // \ /.
  1023. // +-----------+ +-----------+ /. +-----------+ +-----------+ +-----------+
  1024. // | Data | ----> | Case | S--- | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  1025. // +-----------+ +-----------+ \. +-----------+ +-----------+ +-----------+
  1026. // \ \.
  1027. // \ \. +-----------+ +-----------+ +-----------+
  1028. // +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  1029. // | NetOutput | +-----------+ +-----------+ +-----------+
  1030. // +-----------+
  1031. // +-----------+ /
  1032. // | Data | --------------->/
  1033. // +-----------+
  1034. void GetDynamicShapeByGraph(const ComputeGraphPtr &graph, const NodePtr &node,
  1035. set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1036. GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1037. const auto &func_desc = node->GetOpDesc();
  1038. if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
  1039. GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1040. return;
  1041. }
  1042. const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames();
  1043. for (size_t i = 0; i < func_desc->GetOutputsSize(); ++i) {
  1044. for (size_t j = 0; j < dynamic_branch_names.size(); ++j) {
  1045. const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[j]);
  1046. if (subgraph == nullptr) {
  1047. GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s", dynamic_branch_names[j].c_str());
  1048. dynamic_output_dims.clear();
  1049. return;
  1050. }
  1051. const auto &out_node = subgraph->FindFirstNodeMatchType(NETOUTPUT);
  1052. if (out_node == nullptr) {
  1053. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "NetOutput not found, name: %s", dynamic_branch_names[j].c_str());
  1054. dynamic_output_dims.clear();
  1055. return;
  1056. }
  1057. GELOGI("Find the subgraph Output node %s and the index is %zu", out_node->GetName().c_str(), i);
  1058. const auto &out_desc = out_node->GetOpDesc();
  1059. if (out_desc == nullptr || out_desc->GetInputsSize() <= i) {
  1060. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Get Input desc failed, name: %s, index: %zu", out_node->GetName().c_str(), i);
  1061. dynamic_output_dims.clear();
  1062. return;
  1063. }
  1064. const auto &input_tensor = out_desc->GetInputDesc(i);
  1065. const auto &shape_msg = input_tensor.GetShape().ToString();
  1066. string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg;
  1067. GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str());
  1068. dynamic_output_dims.emplace_back(output_shape);
  1069. uint32_t parent_index = 0;
  1070. (void)AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  1071. dynamic_output_index.insert(parent_index);
  1072. }
  1073. }
  1074. }
  1075. // +-----------+ +-----------+ i = 0
  1076. // +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> \.
  1077. // / +-----------+ +-----------+ \.
  1078. // / \.
  1079. // +-----------+ +-----------+ +-----------+ +-----------+ i = 1 +-----------+
  1080. // | Data | ----> | SwitchN | ----> | SoftmaxV2 | ----> |MemcpyAsync| ----> | Merge |
  1081. // +-----------+ +-----------+ +-----------+ +-----------+ +-----------+
  1082. // \ / \. j = 0
  1083. // \ +-----------+ +-----------+ i = 2 / \.
  1084. // +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> / +-----------+
  1085. // +-----------+ +-----------+ | NetOutput |
  1086. // +-----------+
  1087. // +-----------+ /.
  1088. // | Data | --------------------------------------------------------------------------->/. j = 1
  1089. // +-----------+
  1090. void GetDynamicShapeByMerge(const ComputeGraphPtr &graph, const NodePtr &node,
  1091. set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1092. GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1093. const auto &netoutput_desc = node->GetOpDesc();
  1094. const auto &inputnode_to_netoutput = node->GetInAllNodes();
  1095. for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) {
  1096. bool insert_by_mbatch = false;
  1097. (void)AttrUtils::GetBool(inputnode_to_netoutput.at(i)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, insert_by_mbatch);
  1098. if (inputnode_to_netoutput.at(i)->GetType() == MERGE && insert_by_mbatch) {
  1099. GELOGI("Find the merge node %s with mbatch attr and the index is %zu",
  1100. inputnode_to_netoutput.at(i)->GetName().c_str(), i);
  1101. dynamic_output_index.insert(i);
  1102. for (size_t j = 0; j < inputnode_to_netoutput.at(i)->GetInNodes().size(); ++j) {
  1103. auto input_desc = inputnode_to_netoutput.at(i)->GetOpDesc();
  1104. auto input_tensor_desc = input_desc->GetInputDesc(j);
  1105. auto shape_msg = input_tensor_desc.GetShape().ToString();
  1106. string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg;
  1107. GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str());
  1108. dynamic_output_dims.emplace_back(output_shape);
  1109. }
  1110. }
  1111. }
  1112. }
  1113. // Connect NetOutput directly: DTS2020070612498
  1114. void GetDirectOutputShape(const ComputeGraphPtr &graph, const NodePtr &node,
  1115. const set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1116. GELOGD("Try get directly shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1117. const auto &netoutput_desc = node->GetOpDesc();
  1118. const auto &inputnode_to_netoutput = node->GetInAllNodes();
  1119. for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) {
  1120. if (dynamic_output_index.count(i) > 0) {
  1121. continue;
  1122. }
  1123. auto tensor_desc = netoutput_desc->GetInputDesc(i);
  1124. auto shape = tensor_desc.GetShape().ToString();
  1125. string static_output_shape = std::to_string(kStaticOutput) + "," + std::to_string(i) + "," + shape;
  1126. GELOGI("The static output shape msg is %s", static_output_shape.c_str());
  1127. dynamic_output_dims.emplace_back(static_output_shape);
  1128. }
  1129. }
  1130. Status GetDynamicOutputShape(ComputeGraphPtr &graph) {
  1131. GE_CHECK_NOTNULL(graph);
  1132. GELOGI("Start to get output dynamic batch shape message");
  1133. NodePtr net_output;
  1134. set<size_t> dynamic_output_index;
  1135. vector<string> dynamic_output_dims;
  1136. for (auto &node : graph->GetDirectNode()) {
  1137. if (node->GetType() == NETOUTPUT) {
  1138. net_output = node;
  1139. GetDynamicShapeByMerge(graph, node, dynamic_output_index, dynamic_output_dims);
  1140. } else if (node->GetType() == CASE) {
  1141. GetDynamicShapeByGraph(graph, node, dynamic_output_index, dynamic_output_dims);
  1142. }
  1143. }
  1144. if ((net_output != nullptr) && !dynamic_output_dims.empty()) {
  1145. GetDirectOutputShape(graph, net_output, dynamic_output_index, dynamic_output_dims);
  1146. if (!AttrUtils::SetListStr(net_output->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims)) {
  1147. GELOGE(FAILED, "Set dynamic output dims attr failed");
  1148. return FAILED;
  1149. }
  1150. }
  1151. return SUCCESS;
  1152. }
  1153. } // namespace multibatch
  1154. } // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore. It is implemented in C++, sits between the front-end module ME and the underlying hardware, and serves as the bridge between them. GE takes the graph issued by ME as input, applies a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.