You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

multi_batch_pass.cc 35 kB

5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/multi_batch_pass.h"
  17. #include <stack>
  18. #include <unordered_set>
  19. #include "common/ge/ge_util.h"
  20. #include "graph/common/omg_util.h"
  21. #include "graph/utils/type_utils.h"
  22. #include "common/formats/utils/formats_trans_utils.h"
  23. namespace ge {
  24. Status MultiBatchPass::Run(ComputeGraphPtr graph) {
  25. GELOGD("MultiBatchPass Enter");
  26. if (graph->GetParentGraph() != nullptr) {
  27. GELOGI("Subgraph %s skip the MultiBatchPass.", graph->GetName().c_str());
  28. return SUCCESS;
  29. }
  30. OutDataAnchorPtr pred_value = nullptr;
  31. Status ret = FindPredValue(graph, pred_value);
  32. if (ret == NOT_CHANGED) {
  33. GELOGD("SwitchN node not exist, graph not changed.");
  34. return SUCCESS;
  35. }
  36. if (ret != SUCCESS) {
  37. GELOGE(FAILED, "FindPredValue failed.");
  38. return FAILED;
  39. }
  40. if (GetDynamicType() != SUCCESS) {
  41. GELOGE(FAILED, "Get dynamic type failed.");
  42. return FAILED;
  43. }
  44. if (GetUserDesignateShape() != SUCCESS) {
  45. GELOGE(FAILED, "Get user designate shape failed.");
  46. return FAILED;
  47. }
  48. std::vector<std::vector<int64_t>> batch_shape;
  49. std::vector<std::vector<int64_t>> combined_batch;
  50. if (!CheckSwitchN(batch_shape, combined_batch)) {
  51. GELOGE(FAILED, "CheckSwitchN failed.");
  52. return FAILED;
  53. }
  54. if (attach_label_only_) {
  55. return AttachLabelOnly(batch_shape.size());
  56. }
  57. if (FindSwitchOutNodes(batch_shape.size()) != SUCCESS) {
  58. GELOGE(FAILED, "Find SwitchN out nodes failed.");
  59. return FAILED;
  60. }
  61. if (ReplaceSwitchN(graph, pred_value, batch_shape, combined_batch) != SUCCESS) {
  62. GELOGE(FAILED, "Replace SwitchN nodes failed.");
  63. return FAILED;
  64. }
  65. for (const NodePtr &node : bypass_nodes_) {
  66. if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
  67. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  68. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  69. GELOGE(FAILED, "Remove SwitchN nodes %s failed.", node->GetName().c_str());
  70. return FAILED;
  71. }
  72. }
  73. GELOGD("MultiBatchPass Leave");
  74. return SUCCESS;
  75. }
  76. ///
  77. /// @brief Clear Status
  78. /// @return
  79. ///
  80. Status MultiBatchPass::ClearStatus() {
  81. switch_n_nodes_.clear();
  82. bypass_nodes_.clear();
  83. batch_head_nodes_.clear();
  84. return SUCCESS;
  85. }
  86. ///
  87. /// @ingroup ge
  88. /// @brief Set batch label for Case mode.
  89. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  90. /// @param [in] const NodePtr &case_node: Case Node.
  91. /// @return 0: SUCCESS / others: FAILED
  92. ///
  93. Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) {
  94. const auto &func_desc = case_node->GetOpDesc();
  95. GE_CHECK_NOTNULL(func_desc);
  96. if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
  97. GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str());
  98. return SUCCESS;
  99. }
  100. const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames();
  101. for (size_t i = 0; i < dynamic_branch_names.size(); ++i) {
  102. const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]);
  103. GE_CHECK_NOTNULL(subgraph);
  104. const std::string batch_label = "Batch_" + std::to_string(i);
  105. for (const auto &node : subgraph->GetDirectNode()) {
  106. (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
  107. }
  108. }
  109. return SUCCESS;
  110. }
  111. ///
  112. /// @brief Replace & Combine SwitchN nodes
  113. /// @param [in] graph
  114. /// @param [out] pred_value
  115. /// @return Status
  116. ///
  117. Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value) {
  118. for (const NodePtr &node : graph->GetDirectNode()) {
  119. if (node->GetType() == CASE) {
  120. GE_CHK_STATUS_RET(SetCaseLabel(graph, node), "Set batch label failed");
  121. continue;
  122. }
  123. if (node->GetType() != SWITCHN) {
  124. continue;
  125. }
  126. const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
  127. if (in_data_anchor == nullptr) {
  128. REPORT_INNER_ERROR("E19999", "Index:%u data anchor of node:%s(%s) is nullptr, check invalid",
  129. SWITCH_PRED_INPUT, node->GetName().c_str(), node->GetType().c_str());
  130. GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str());
  131. return FAILED;
  132. }
  133. const auto &pred_input = in_data_anchor->GetPeerOutAnchor();
  134. if (pred_input == nullptr) {
  135. REPORT_INNER_ERROR("E19999", "Index:%u data anchor of node:%s(%s), its peer anchor is nullptr, check invalid",
  136. SWITCH_PRED_INPUT, node->GetName().c_str(), node->GetType().c_str());
  137. GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str());
  138. return FAILED;
  139. }
  140. if (pred_value == nullptr) {
  141. pred_value = pred_input;
  142. } else if (pred_value != pred_input) {
  143. REPORT_INNER_ERROR("E19999", "Multi pred_value of case node exist in graph:%s, check invalid",
  144. graph->GetName().c_str());
  145. GELOGE(FAILED, "Multi pred_value node exist.");
  146. return FAILED;
  147. }
  148. switch_n_nodes_.emplace_back(node);
  149. }
  150. if (switch_n_nodes_.empty()) {
  151. GELOGD("SwitchN node not exist.");
  152. return NOT_CHANGED;
  153. }
  154. if (pred_value == nullptr) {
  155. REPORT_INNER_ERROR("E19999", "Find Pred Input of case node in graph:%s failed", graph->GetName().c_str());
  156. GELOGE(FAILED, "FindPredInput failed, pred_value is null.");
  157. return FAILED;
  158. }
  159. GELOGI("Find pred_value %s.", pred_value->GetOwnerNode()->GetName().c_str());
  160. return SUCCESS;
  161. }
  162. ///
  163. /// @brief Get dynamic type: dynamic batch size: 1, dynamic image size: 2, dynamic dims: 3
  164. /// @return Status
  165. ///
  166. Status MultiBatchPass::GetDynamicType() {
  167. for (const auto &switch_n : switch_n_nodes_) {
  168. int32_t dynamic_type = static_cast<int32_t>(FIXED);
  169. if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) {
  170. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_DYNAMIC_TYPE.c_str(),
  171. switch_n->GetName().c_str(), switch_n->GetType().c_str());
  172. GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str());
  173. return FAILED;
  174. }
  175. if (dynamic_type == static_cast<int32_t>(FIXED)) {
  176. REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%d check invalid", ATTR_DYNAMIC_TYPE.c_str(),
  177. switch_n->GetName().c_str(), switch_n->GetType().c_str(), dynamic_type);
  178. GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0.");
  179. return FAILED;
  180. }
  181. if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) {
  182. REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%d not same as attr value:%d in node before, "
  183. "check invalid",
  184. ATTR_DYNAMIC_TYPE.c_str(), switch_n->GetName().c_str(), switch_n->GetType().c_str(),
  185. dynamic_type, dynamic_type_);
  186. GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.",
  187. dynamic_type, dynamic_type_);
  188. return FAILED;
  189. }
  190. dynamic_type_ = dynamic_type;
  191. }
  192. if (dynamic_type_ == static_cast<int32_t>(FIXED)) {
  193. REPORT_INNER_ERROR("E19999", "Find Attr:%s in all switcnn node failed", ATTR_DYNAMIC_TYPE.c_str());
  194. GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0.");
  195. return FAILED;
  196. }
  197. return SUCCESS;
  198. }
  199. ///
  200. /// @brief Get user designate shape order. eg{"data","label","mask"}
  201. /// @return Status
  202. ///
  203. Status MultiBatchPass::GetUserDesignateShape() {
  204. data_name_order_.clear();
  205. bool first_check = true;
  206. for (const auto &switch_n : switch_n_nodes_) {
  207. std::vector<std::string> cur_data_name_order;
  208. if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) {
  209. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
  210. switch_n->GetName().c_str(), switch_n->GetType().c_str());
  211. GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str());
  212. return FAILED;
  213. }
  214. if (first_check) {
  215. data_name_order_ = cur_data_name_order;
  216. first_check = false;
  217. } else {
  218. if (data_name_order_ != cur_data_name_order) {
  219. REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%s not same as attr value:%s in node before, "
  220. "check invalid", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
  221. switch_n->GetName().c_str(), switch_n->GetType().c_str(),
  222. formats::JoinToString(cur_data_name_order).c_str(),
  223. formats::JoinToString(data_name_order_).c_str());
  224. GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.",
  225. switch_n->GetName().c_str());
  226. return FAILED;
  227. }
  228. }
  229. }
  230. if (data_name_order_.empty()) {
  231. REPORT_INNER_ERROR("E19999", "Find Attr:%s in all switcnn node failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str());
  232. GELOGE(FAILED, "user shape order can not be empty");
  233. return FAILED;
  234. }
  235. return SUCCESS;
  236. }
  237. ///
  238. /// @brief Check SwitchN nodes
  239. /// @param [out] batch_shape
  240. /// @param [out] combined_batch
  241. /// @return bool
  242. ///
  243. bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape,
  244. std::vector<std::vector<int64_t>> &combined_batch) {
  245. // Check if output_num of different SwitchN is same
  246. uint32_t batch_num = 0;
  247. for (const NodePtr &node : switch_n_nodes_) {
  248. uint32_t tmp_num = node->GetAllOutDataAnchorsSize();
  249. if (batch_num == 0) {
  250. batch_num = tmp_num;
  251. } else if (batch_num != tmp_num) {
  252. REPORT_INNER_ERROR("E19999", "Ouput size num:%u of node:%s(%s) not same as output size num:%d of node before, "
  253. "check invalid", tmp_num, node->GetName().c_str(), node->GetType().c_str(), batch_num);
  254. GELOGE(FAILED, "Output size of SwitchN not equal;");
  255. return false;
  256. }
  257. }
  258. if (!GetBatchInfo(batch_num, batch_shape, combined_batch)) {
  259. GELOGE(FAILED, "Get batch info failed.");
  260. return false;
  261. }
  262. if (batch_shape.empty()) {
  263. REPORT_INNER_ERROR("E19999", "batch_shape size is empty after GetBatchInfo, check invalid");
  264. GELOGE(FAILED, "batch_shape is empty.");
  265. return false;
  266. }
  267. if (combined_batch.empty()) {
  268. REPORT_INNER_ERROR("E19999", "combined_batch size is empty after GetBatchInfo, check invalid");
  269. GELOGE(FAILED, "combined_batch is empty.");
  270. return false;
  271. }
  272. size_t dim_num = batch_shape[0].size();
  273. size_t combined_dim_num = combined_batch[0].size();
  274. for (uint32_t i = 1; i < batch_num; i++) {
  275. size_t tmp_dim_num = batch_shape[i].size();
  276. if (dim_num != tmp_dim_num) {
  277. REPORT_INNER_ERROR("E19999", "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu, check invalid",
  278. dim_num, i, tmp_dim_num);
  279. GELOGE(FAILED, "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num);
  280. return false;
  281. }
  282. size_t tmp_combined_dim_num = combined_batch[i].size();
  283. if (combined_dim_num != tmp_combined_dim_num) {
  284. REPORT_INNER_ERROR("E19999", "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu, check invalid",
  285. combined_dim_num, i, tmp_combined_dim_num);
  286. GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.",
  287. combined_dim_num, i, tmp_combined_dim_num);
  288. return false;
  289. }
  290. }
  291. return true;
  292. }
  293. ///
  294. /// @brief Check SwitchN nodes
  295. /// @param [in] batch_num
  296. /// @param [out] batch_shape
  297. /// @param [out] combined_batch
  298. /// @return bool
  299. ///
  300. bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape,
  301. std::vector<std::vector<int64_t>> &combined_batch) {
  302. // Check if output_shape of different SwitchN is same
  303. std::vector<std::vector<int64_t>> idx_batch_shape;
  304. std::vector<std::vector<int64_t>> idx_combined_batch;
  305. for (uint32_t i = 0; i < batch_num; i++) {
  306. idx_batch_shape.clear();
  307. idx_combined_batch.clear();
  308. for (const NodePtr &node : switch_n_nodes_) {
  309. OpDescPtr op_desc = node->GetOpDesc();
  310. if (op_desc == nullptr) {
  311. REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid");
  312. GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str());
  313. return false;
  314. }
  315. std::vector<int64_t> output_dims;
  316. if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) {
  317. REPORT_CALL_ERROR("E19999", "Get Attr:%s from output:%u tensor of op:%s(%s) failed",
  318. ATTR_NAME_SWITCHN_PRED_VALUE.c_str(), i,
  319. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  320. GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i);
  321. return false;
  322. }
  323. idx_batch_shape.emplace_back(output_dims);
  324. output_dims.clear();
  325. if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_COMBINED_DYNAMIC_DIMS, output_dims)) {
  326. REPORT_CALL_ERROR("E19999", "Get Attr:%s from output:%u tensor of op:%s(%s) failed",
  327. ATTR_NAME_COMBINED_DYNAMIC_DIMS.c_str(), i,
  328. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  329. GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_COMBINED_DYNAMIC_DIMS failed, batch_index=%u.", i);
  330. return false;
  331. }
  332. idx_combined_batch.emplace_back(output_dims);
  333. }
  334. if (!CheckDims(idx_batch_shape)) {
  335. REPORT_INNER_ERROR("E19999", "Attr:%s of all output:%u tensor in switcnn node not equal, or not exist, "
  336. "check invalid", ATTR_NAME_SWITCHN_PRED_VALUE.c_str(), i);
  337. GELOGE(FAILED, "CheckDims failed, batch_index=%u.", i);
  338. return false;
  339. }
  340. batch_shape.emplace_back(idx_batch_shape[0]);
  341. combined_batch.emplace_back(idx_combined_batch[0]);
  342. }
  343. return true;
  344. }
  345. ///
  346. /// @brief Find outputs of SwitchN nodes
  347. /// @param [in] batch_num
  348. /// @return void
  349. ///
  350. Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
  351. std::vector<NodePtr> output_nodes;
  352. for (uint32_t i = 0; i < batch_num; i++) {
  353. output_nodes.clear();
  354. for (const NodePtr &node : switch_n_nodes_) {
  355. // idx is promised to be valid
  356. OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
  357. GE_CHECK_NOTNULL(out_data_anchor);
  358. for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  359. auto out_node = peer_in_anchor->GetOwnerNode();
  360. if (out_node->GetType() != IDENTITY || !out_node->GetOutDataNodes().empty()) {
  361. output_nodes.emplace_back(out_node);
  362. continue;
  363. }
  364. bypass_nodes_.emplace_back(out_node);
  365. if (GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  366. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  367. node->GetName().c_str(), node->GetType().c_str(), i,
  368. out_node->GetName().c_str(), out_node->GetType().c_str(), peer_in_anchor->GetIdx());
  369. GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(),
  370. out_node->GetName().c_str());
  371. return FAILED;
  372. }
  373. for (auto &identity_out_node : out_node->GetOutControlNodes()) {
  374. output_nodes.emplace_back(identity_out_node);
  375. if (GraphUtils::RemoveEdge(out_node->GetOutControlAnchor(), identity_out_node->GetInControlAnchor()) !=
  376. GRAPH_SUCCESS) {
  377. REPORT_CALL_ERROR("E19999", "Remove control edge between op:%s(%s) and op:%s(%s) failed",
  378. out_node->GetName().c_str(), out_node->GetType().c_str(),
  379. identity_out_node->GetName().c_str(), identity_out_node->GetType().c_str());
  380. GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(),
  381. out_node->GetName().c_str());
  382. return FAILED;
  383. }
  384. }
  385. }
  386. }
  387. batch_head_nodes_.emplace_back(output_nodes);
  388. }
  389. return SUCCESS;
  390. }
  391. ///
  392. /// @brief Replace & Combine SwitchN nodes
  393. /// @param [in] graph
  394. /// @param [in] pred_value
  395. /// @param [in] batch_shape
  396. /// @param [in] combined_batch
  397. /// @return Status
  398. ///
  399. Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
  400. const std::vector<std::vector<int64_t>> &batch_shape,
  401. const std::vector<std::vector<int64_t>> &combined_batch) {
  402. NodePtr pred_value_node = pred_value->GetOwnerNode();
  403. // Create SwitchCase node
  404. const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
  405. NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape, combined_batch);
  406. if (switch_case == nullptr) {
  407. GELOGE(FAILED, "CreateSwitchCaseNode %s failed.", switch_case_name.c_str());
  408. return FAILED;
  409. }
  410. for (const NodePtr &switch_n_node : switch_n_nodes_) {
  411. if (BypassSwitchN(switch_n_node, switch_case) != SUCCESS) {
  412. GELOGE(FAILED, "Bypass SwitchN %s failed.", switch_case_name.c_str());
  413. return FAILED;
  414. }
  415. }
  416. // Add switchCase input edge
  417. if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  418. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  419. pred_value_node->GetName().c_str(), pred_value_node->GetType().c_str(), pred_value->GetIdx(),
  420. switch_case->GetName().c_str(), switch_case->GetType().c_str());
  421. GELOGE(FAILED, "Add SwitchCase in_data_edge failed, %s->%s.", pred_value_node->GetName().c_str(),
  422. switch_case->GetName().c_str());
  423. return FAILED;
  424. }
  425. if (AttachLabel(switch_case) != SUCCESS) {
  426. GELOGE(FAILED, "AttachLabel failed.");
  427. return FAILED;
  428. }
  429. return SUCCESS;
  430. }
  431. ///
  432. /// @brief Check if output_shape of different SwitchN is same
  433. /// @param [in] output_shape
  434. /// @return bool
  435. ///
  436. bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_shape) const {
  437. if (output_shape.empty()) {
  438. GELOGE(FAILED, "CheckDims failed: output_shape is empty.");
  439. return false;
  440. }
  441. for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) {
  442. if (output_shape[0] != *iter) {
  443. return false;
  444. }
  445. }
  446. return true;
  447. }
  448. ///
  449. /// @brief Create StreamSwitchN node
  450. /// @param [in] graph
  451. /// @param [in] name
  452. /// @param [in] pred_value
  453. /// @param [in] batch_shape
  454. /// @param [in] combined_batch
  455. /// @return ge::NodePtr
  456. ///
  457. NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name,
  458. const OutDataAnchorPtr &pred_value,
  459. const std::vector<std::vector<int64_t>> &batch_shape,
  460. const std::vector<std::vector<int64_t>> &combined_batch) {
  461. OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN);
  462. if (op_desc == nullptr) {
  463. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  464. GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str());
  465. return nullptr;
  466. }
  467. GELOGI("Create StreamSwitchN op:%s.", name.c_str());
  468. OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc();
  469. if (pred_desc == nullptr) {
  470. REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid");
  471. GELOGE(FAILED, "Get pred_desc failed, StreamSwitchN:%s.", name.c_str());
  472. return nullptr;
  473. }
  474. if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) {
  475. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  476. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  477. GELOGE(FAILED, "AddInputDesc failed, StreamSwitchN:%s.", name.c_str());
  478. return nullptr;
  479. }
  480. NodePtr switch_case_node = graph->AddNode(op_desc);
  481. if (switch_case_node == nullptr) {
  482. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  483. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  484. GELOGE(FAILED, "Create node failed, StreamSwitchN:%s.", name.c_str());
  485. return nullptr;
  486. }
  487. uint32_t batch_num = static_cast<uint32_t>(batch_shape.size());
  488. if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
  489. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_BATCH_NUM.c_str(),
  490. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  491. GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM failed, StreamSwitchN:%s.", name.c_str());
  492. return nullptr;
  493. }
  494. if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_)) {
  495. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_DYNAMIC_TYPE.c_str(),
  496. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  497. GELOGE(FAILED, "Set attr ATTR_DYNAMIC_TYPE failed, StreamSwitchN:%s.", name.c_str());
  498. return nullptr;
  499. }
  500. if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
  501. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
  502. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  503. GELOGE(FAILED, "Set attr ATTR_USER_DESIGNEATE_SHAPE_ORDER failed, StreamSwitchN:%s.", name.c_str());
  504. return nullptr;
  505. }
  506. for (uint32_t i = 0; i < batch_num; i++) {
  507. const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
  508. if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) {
  509. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", attr_name.c_str(),
  510. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  511. GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str());
  512. return nullptr;
  513. }
  514. const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
  515. if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) {
  516. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", attr_combined_batch.c_str(),
  517. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  518. GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str());
  519. return nullptr;
  520. }
  521. }
  522. return switch_case_node;
  523. }
  524. ///
  525. /// @brief Bypass SwitchN node
  526. /// @param [in] switch_n_node
  527. /// @param [in] switch_case
  528. /// @return Status
  529. ///
  530. Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) {
  531. InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT);
  532. if (in_data_anchor == nullptr) {
  533. REPORT_INNER_ERROR("E19999", "Index:%u in data anchor of node:%s(%s) is nullptr, check invalid",
  534. SWITCH_DATA_INPUT, switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str());
  535. GELOGE(FAILED, "Check in_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str());
  536. return FAILED;
  537. }
  538. OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
  539. if (peer_data_anchor == nullptr) {
  540. REPORT_INNER_ERROR("E19999", "Index:%u in data anchor of node:%s(%s), its peer ahcnhor is nullptr, check invalid",
  541. SWITCH_DATA_INPUT, switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str());
  542. GELOGE(FAILED, "Check peer_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str());
  543. return FAILED;
  544. }
  545. NodePtr data_input = peer_data_anchor->GetOwnerNode();
  546. // Remove SwitchN data input
  547. if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
  548. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%u) failed",
  549. data_input->GetName().c_str(), data_input->GetType().c_str(), peer_data_anchor->GetIdx(),
  550. switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str(), SWITCH_DATA_INPUT);
  551. GELOGE(FAILED, "Remove SwitchN in_data_edge failed, %s->%s.", data_input->GetName().c_str(),
  552. switch_n_node->GetName().c_str());
  553. return FAILED;
  554. }
  555. if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) {
  556. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  557. data_input->GetName().c_str(), data_input->GetType().c_str(),
  558. switch_case->GetName().c_str(), switch_case->GetType().c_str());
  559. GELOGE(FAILED, "Add StreamSwitchN in_control_edge failed, %s->%s.", data_input->GetName().c_str(),
  560. switch_case->GetName().c_str());
  561. return FAILED;
  562. }
  563. // Add SwitchCase control output
  564. for (const OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) {
  565. for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  566. NodePtr data_output = peer_in_anchor->GetOwnerNode();
  567. if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) ||
  568. (GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) {
  569. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) or "
  570. "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  571. switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str(), out_data_anchor->GetIdx(),
  572. data_output->GetName().c_str(), data_output->GetType().c_str(), peer_in_anchor->GetIdx(),
  573. data_input->GetName().c_str(), data_input->GetType().c_str(), peer_data_anchor->GetIdx(),
  574. data_output->GetName().c_str(), data_output->GetType().c_str(), peer_in_anchor->GetIdx());
  575. GELOGE(FAILED, "Bypass SwitchN data_edge failed, %s->%s->%s.", data_input->GetName().c_str(),
  576. switch_n_node->GetName().c_str(), data_output->GetName().c_str());
  577. return FAILED;
  578. }
  579. if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) {
  580. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  581. switch_case->GetName().c_str(), switch_case->GetType().c_str(),
  582. data_output->GetName().c_str(), data_output->GetType().c_str());
  583. GELOGE(FAILED, "Add SwitchCase out_control_edge failed, %s->%s.", switch_case->GetName().c_str(),
  584. data_output->GetName().c_str());
  585. return FAILED;
  586. }
  587. }
  588. }
  589. GE_CHK_STATUS_RET(MoveCtrlEdges(switch_n_node, switch_case), "Move ctrl edges failed.");
  590. bypass_nodes_.emplace_back(switch_n_node);
  591. GELOGI("Bypass SwitchN node %s success.", switch_n_node->GetName().c_str());
  592. return SUCCESS;
  593. }
  594. ///
  595. /// @brief Attach stream_label & batch_label for batch branch
  596. /// @param [in] switch_case_node
  597. /// @return Status
  598. ///
  599. Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) {
  600. std::vector<std::string> stream_label_list;
  601. for (uint32_t i = 0; i < static_cast<uint32_t>(batch_head_nodes_.size()); i++) {
  602. if (AttachBatchLabel(i) != SUCCESS) {
  603. GELOGE(FAILED, "AttachBatchLabel failed, batch_idx=%u", i);
  604. return FAILED;
  605. }
  606. const std::string &stream_label = "stream_label_batch_" + std::to_string(i);
  607. if (AttachStreamLabel(i, stream_label) != SUCCESS) {
  608. GELOGE(FAILED, "AttachStreamLabel failed, stream_label=%s", stream_label.c_str());
  609. return FAILED;
  610. }
  611. stream_label_list.emplace_back(stream_label);
  612. }
  613. return switch_case_node == nullptr ? SUCCESS : SetActiveLabelList(switch_case_node, stream_label_list);
  614. }
  615. ///
  616. /// @brief Attach batch_label for batch branch
  617. /// @param [in] batch_idx
  618. /// @return Status
  619. ///
  620. Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) {
  621. std::stack<NodePtr> nodes;
  622. for (const auto &node : batch_head_nodes_[batch_idx]) {
  623. nodes.push(node);
  624. }
  625. const std::string &batch_label = "Batch_" + std::to_string(batch_idx);
  626. std::unordered_set<NodePtr> handled_nodes;
  627. while (!nodes.empty()) {
  628. NodePtr cur_node = nodes.top();
  629. nodes.pop();
  630. if (handled_nodes.count(cur_node) > 0) {
  631. continue;
  632. }
  633. OpDescPtr cur_desc = cur_node->GetOpDesc();
  634. GE_CHECK_NOTNULL(cur_desc);
  635. if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
  636. std::string tmp_label;
  637. if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) {
  638. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_BATCH_LABEL.c_str(),
  639. cur_desc->GetName().c_str(), cur_desc->GetType().c_str());
  640. GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL failed, node: %s.", cur_desc->GetName().c_str());
  641. return FAILED;
  642. }
  643. if (tmp_label != batch_label) {
  644. REPORT_INNER_ERROR("E19999", "Attr:%s from op:%s(%s) value:%s not equal to expect:%s, check invalid",
  645. ATTR_NAME_BATCH_LABEL.c_str(), cur_desc->GetName().c_str(), cur_desc->GetType().c_str(),
  646. tmp_label.c_str(), batch_label.c_str());
  647. GELOGE(FAILED, "Reach other batch_branch, node:%s, cur_label:%s, batch_label:%s.", cur_desc->GetName().c_str(),
  648. tmp_label.c_str(), batch_label.c_str());
  649. return FAILED;
  650. }
  651. }
  652. GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str());
  653. if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  654. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_BATCH_LABEL.c_str(),
  655. cur_desc->GetName().c_str(), cur_desc->GetType().c_str());
  656. GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", cur_desc->GetName().c_str());
  657. return FAILED;
  658. }
  659. for (const auto &out_node : cur_node->GetOutAllNodes()) {
  660. OpDescPtr op_desc = out_node->GetOpDesc();
  661. GE_CHECK_NOTNULL(op_desc);
  662. const std::string &type = op_desc->GetType();
  663. if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
  664. continue;
  665. }
  666. if (type == NETOUTPUT) {
  667. REPORT_CALL_ERROR("E19999", "SReach net_output without Merge, cur_node:%s(%s), check invalid",
  668. cur_node->GetName().c_str(), cur_node->GetType().c_str());
  669. GELOGE(FAILED, "Reach net_output without Merge, cur_node:%s.", cur_node->GetName().c_str());
  670. return FAILED;
  671. }
  672. nodes.push(out_node);
  673. }
  674. (void)handled_nodes.insert(cur_node);
  675. }
  676. return SUCCESS;
  677. }
  678. ///
  679. /// @brief Attach stream_label for batch branch
  680. /// @param [in] batch_idx
  681. /// @param [in] stream_label
  682. /// @return Status
  683. ///
  684. Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label) {
  685. std::stack<NodePtr> nodes;
  686. for (const auto &node : batch_head_nodes_[batch_idx]) {
  687. nodes.push(node);
  688. }
  689. std::unordered_set<NodePtr> handled_nodes;
  690. while (!nodes.empty()) {
  691. NodePtr cur_node = nodes.top();
  692. nodes.pop();
  693. OpDescPtr cur_desc = cur_node->GetOpDesc();
  694. GE_CHECK_NOTNULL(cur_desc);
  695. if ((handled_nodes.count(cur_node) > 0) || (cur_desc->HasAttr(ATTR_NAME_STREAM_LABEL))) {
  696. continue;
  697. }
  698. GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str());
  699. if (SetStreamLabel(cur_node, stream_label) != SUCCESS) {
  700. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  701. stream_label.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str());
  702. GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str());
  703. return FAILED;
  704. }
  705. for (const auto &out_node : cur_node->GetOutAllNodes()) {
  706. nodes.push(out_node);
  707. }
  708. (void)handled_nodes.insert(cur_node);
  709. }
  710. return SUCCESS;
  711. }
  712. ///
  713. /// @brief move edges from old_node to new_node
  714. /// @param [in] old_node
  715. /// @param [in] new_node
  716. /// @return Status
  717. ///
  718. Status MultiBatchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) {
  719. if (old_node == new_node) {
  720. return SUCCESS;
  721. }
  722. for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) {
  723. GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()),
  724. "Merge remove in ctrl edge failed.");
  725. GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()),
  726. "StreamMerge add in ctrl edge failed.");
  727. }
  728. for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) {
  729. GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
  730. "Merge remove out ctrl edge failed.");
  731. GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
  732. "StreamMerge add out ctrl edge failed.");
  733. }
  734. return SUCCESS;
  735. }
  736. ///
  737. /// @brief attach stream_label & batch_label without change structure of graph
  738. /// @param [in] batch_num
  739. /// @return void
  740. ///
  741. Status MultiBatchPass::AttachLabelOnly(uint32_t batch_num) {
  742. std::vector<NodePtr> output_nodes;
  743. for (uint32_t i = 0; i < batch_num; i++) {
  744. output_nodes.clear();
  745. for (const NodePtr &node : switch_n_nodes_) {
  746. // idx is promised to be valid
  747. OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
  748. GE_CHECK_NOTNULL(out_data_anchor);
  749. for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  750. output_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
  751. }
  752. }
  753. batch_head_nodes_.emplace_back(output_nodes);
  754. }
  755. return AttachLabel(nullptr);
  756. }
  757. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承上启下的作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练或推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。