
multi_batch_pass.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/multi_batch_pass.h"

#include <stack>
#include <unordered_set>

#include "common/ge/ge_util.h"
#include "graph/common/omg_util.h"
#include "graph/utils/type_utils.h"

using std::string;
using std::vector;

namespace ge {
Status MultiBatchPass::Run(ComputeGraphPtr graph) {
  GELOGD("MultiBatchPass Enter");
  if (graph->GetParentGraph() != nullptr) {
    GELOGI("Subgraph %s skip the MultiBatchPass.", graph->GetName().c_str());
    return SUCCESS;
  }
  OutDataAnchorPtr pred_value = nullptr;
  Status ret = FindPredValue(graph, pred_value);
  if (ret == NOT_CHANGED) {
    GELOGI("SwitchN node not exist, graph not changed.");
    return SUCCESS;
  }
  if (ret != SUCCESS) {
    GELOGE(FAILED, "FindPredValue failed.");
    return FAILED;
  }
  if (GetDynamicType() != SUCCESS) {
    GELOGE(FAILED, "Get dynamic type failed.");
    return FAILED;
  }
  if (GetUserDesignateShape() != SUCCESS) {
    GELOGE(FAILED, "Get user designate shape failed.");
    return FAILED;
  }
  std::vector<std::vector<int64_t>> batch_shape;
  vector<vector<int64_t>> combined_batch;
  if (!CheckSwitchN(batch_shape, combined_batch)) {
    GELOGE(FAILED, "CheckSwitchN failed.");
    return FAILED;
  }
  if (attach_label_only_) {
    return AttachLabelOnly(batch_shape.size());
  }
  if (FindSwitchOutNodes(batch_shape.size()) != SUCCESS) {
    GELOGE(FAILED, "Find SwitchN out nodes failed.");
    return FAILED;
  }
  if (ReplaceSwitchN(graph, pred_value, batch_shape, combined_batch) != SUCCESS) {
    GELOGE(FAILED, "Replace SwitchN nodes failed.");
    return FAILED;
  }
  for (const NodePtr &node : bypass_nodes_) {
    if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Remove SwitchN nodes %s failed.", node->GetName().c_str());
      return FAILED;
    }
  }
  GELOGD("MultiBatchPass Leave");
  return SUCCESS;
}

///
/// @brief Clear Status
/// @return
///
Status MultiBatchPass::ClearStatus() {
  switch_n_nodes_.clear();
  bypass_nodes_.clear();
  batch_head_nodes_.clear();
  return SUCCESS;
}

///
/// @brief Find pred_value of SwitchN nodes
/// @param [in] graph
/// @param [out] pred_value
/// @return Status
///
Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value) {
  for (const NodePtr &node : graph->GetDirectNode()) {
    if (node->GetType() != SWITCHN) {
      continue;
    }
    InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
    if (in_data_anchor == nullptr) {
      GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str());
      return FAILED;
    }
    OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor();
    if (pred_input == nullptr) {
      GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str());
      return FAILED;
    }
    if (pred_value == nullptr) {
      pred_value = pred_input;
    } else if (pred_value != pred_input) {
      GELOGE(FAILED, "Multi pred_value node exist.");
      return FAILED;
    }
    switch_n_nodes_.emplace_back(node);
  }
  if (switch_n_nodes_.empty()) {
    GELOGI("SwitchN node not exist.");
    return NOT_CHANGED;
  }
  if (pred_value == nullptr) {
    GELOGE(FAILED, "FindPredInput failed, pred_value is null.");
    return FAILED;
  }
  GELOGI("Find pred_value %s.", pred_value->GetOwnerNode()->GetName().c_str());
  return SUCCESS;
}

///
/// @brief Get dynamic type: dynamic batch size: 1, dynamic image size: 2, dynamic dims: 3
/// @return Status
///
Status MultiBatchPass::GetDynamicType() {
  for (const auto &switchn : switch_n_nodes_) {
    auto switchn_desc = switchn->GetOpDesc();
    GE_CHECK_NOTNULL(switchn_desc);
    int32_t dynamic_type = static_cast<int32_t>(FIXED);
    if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) {
      GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str());
      return FAILED;
    }
    if (dynamic_type == static_cast<int32_t>(FIXED)) {
      GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0.");
      return FAILED;
    }
    if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) {
      GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.",
             dynamic_type, dynamic_type_);
      return FAILED;
    }
    dynamic_type_ = dynamic_type;
  }
  if (dynamic_type_ == static_cast<int32_t>(FIXED)) {
    GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0.");
    return FAILED;
  }
  return SUCCESS;
}

///
/// @brief Get user designate shape order. eg{"data","label","mask"}
/// @return Status
///
Status MultiBatchPass::GetUserDesignateShape() {
  data_name_order_.clear();
  bool first_check = true;
  for (const auto &switchn : switch_n_nodes_) {
    auto switchn_desc = switchn->GetOpDesc();
    GE_CHECK_NOTNULL(switchn_desc);
    vector<string> cur_switchn_data_name_order;
    if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) {
      GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str());
      return FAILED;
    }
    if (first_check) {
      data_name_order_ = cur_switchn_data_name_order;
      first_check = false;
    } else {
      if (data_name_order_ != cur_switchn_data_name_order) {
  175. GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.",
  176. switchn->GetName().c_str());
        return FAILED;
      }
    }
  }
  if (data_name_order_.empty()) {
    GELOGE(FAILED, "user shape order can not be empty");
    return FAILED;
  }
  return SUCCESS;
}

///
/// @brief Check SwitchN nodes
/// @param [out] batch_shape
/// @param [out] combined_batch
/// @return bool
///
bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) {
  // Check if output_num of different SwitchN is same
  uint32_t batch_num = 0;
  for (const NodePtr &node : switch_n_nodes_) {
    uint32_t tmp_num = node->GetAllOutDataAnchorsSize();
    if (batch_num == 0) {
      batch_num = tmp_num;
    } else if (batch_num != tmp_num) {
  201. GELOGE(FAILED, "Output size of SwitchN not equal;");
      return false;
    }
  }
  if (!GetBatchInfo(batch_num, batch_shape, combined_batch)) {
    GELOGE(FAILED, "Get batch info failed.");
    return false;
  }
  if (batch_shape.empty()) {
    GELOGE(FAILED, "batch_shape is empty.");
    return false;
  }
  if (combined_batch.empty()) {
    GELOGE(FAILED, "combined_batch is empty.");
    return false;
  }
  size_t dim_num = batch_shape[0].size();
  size_t combined_dim_num = combined_batch[0].size();
  for (uint32_t i = 1; i < batch_num; i++) {
    size_t tmp_dim_num = batch_shape[i].size();
    if (dim_num != tmp_dim_num) {
      GELOGE(FAILED, "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num);
      return false;
    }
    size_t tmp_combined_dim_num = combined_batch[i].size();
    if (combined_dim_num != tmp_combined_dim_num) {
  227. GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num);
      return false;
    }
  }
  return true;
}

///
/// @brief Get batch info of SwitchN nodes
/// @param [in] batch_num
/// @param [out] batch_shape
/// @param [out] combined_batch
/// @return bool
///
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape,
                                  vector<vector<int64_t>> &combined_batch) {
  // Check if output_shape of different SwitchN is same
  vector<vector<int64_t>> idx_batch_shape;
  vector<vector<int64_t>> idx_combined_batch;
  for (uint32_t i = 0; i < batch_num; i++) {
    idx_batch_shape.clear();
    idx_combined_batch.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      OpDescPtr op_desc = node->GetOpDesc();
      if (op_desc == nullptr) {
        GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str());
        return false;
      }
      vector<int64_t> output_dims;
      if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) {
        GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i);
        return false;
      }
      idx_batch_shape.emplace_back(output_dims);
      output_dims.clear();
      if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_COMBINED_DYNAMIC_DIMS, output_dims)) {
        GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_COMBINED_DYNAMIC_DIMS failed, batch_index=%u.", i);
        return false;
      }
      idx_combined_batch.emplace_back(output_dims);
    }
    if (!CheckDims(idx_batch_shape)) {
      GELOGE(FAILED, "CheckDims failed, batch_index=%u.", i);
      return false;
    }
    batch_shape.emplace_back(idx_batch_shape[0]);
    combined_batch.emplace_back(idx_combined_batch[0]);
  }
  return true;
}

///
/// @brief Find outputs of SwitchN nodes
/// @param [in] batch_num
/// @return Status
///
Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
  std::vector<NodePtr> output_nodes;
  for (uint32_t i = 0; i < batch_num; i++) {
    output_nodes.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      // idx is promised to be valid
      OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
      GE_CHECK_NOTNULL(out_data_anchor);
      for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        auto out_node = peer_in_anchor->GetOwnerNode();
        if (out_node->GetType() != IDENTITY || !out_node->GetOutDataNodes().empty()) {
          output_nodes.emplace_back(out_node);
          continue;
        }
        bypass_nodes_.emplace_back(out_node);
        if (GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
          GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(),
                 out_node->GetName().c_str());
          return FAILED;
        }
        for (auto &identity_out_node : out_node->GetOutControlNodes()) {
          output_nodes.emplace_back(identity_out_node);
          if (GraphUtils::RemoveEdge(out_node->GetOutControlAnchor(), identity_out_node->GetInControlAnchor()) !=
              GRAPH_SUCCESS) {
  305. GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(),
  306. out_node->GetName().c_str());
            return FAILED;
          }
        }
      }
    }
    batch_head_nodes_.emplace_back(output_nodes);
  }
  return SUCCESS;
}

///
/// @brief Replace & Combine SwitchN nodes
/// @param [in] graph
/// @param [in] pred_value
/// @param [in] batch_shape
/// @param [in] combined_batch
/// @return Status
///
Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
                                      const vector<vector<int64_t>> &batch_shape,
                                      const vector<vector<int64_t>> &combined_batch) {
  NodePtr pred_value_node = pred_value->GetOwnerNode();
  // Create SwitchCase node
  const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
  NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape, combined_batch);
  if (switch_case == nullptr) {
    GELOGE(FAILED, "CreateSwitchCaseNode %s failed.", switch_case_name.c_str());
    return FAILED;
  }
  for (const NodePtr &switch_n_node : switch_n_nodes_) {
    if (BypassSwitchN(switch_n_node, switch_case) != SUCCESS) {
  337. GELOGE(FAILED, "Bypass SwitchN %s failed.", switch_case_name.c_str());
      return FAILED;
    }
  }
  // Add switchCase input edge
  if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add SwitchCase in_data_edge failed, %s->%s.", pred_value_node->GetName().c_str(),
           switch_case->GetName().c_str());
    return FAILED;
  }
  if (AttachLabel(switch_case) != SUCCESS) {
    GELOGE(FAILED, "AttachLabel failed.");
    return FAILED;
  }
  return SUCCESS;
}

///
/// @brief Check if output_shape of different SwitchN is same
/// @param [in] output_shape
/// @return bool
///
bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_shape) const {
  if (output_shape.empty()) {
    GELOGE(FAILED, "CheckDims failed: output_shape is empty.");
    return false;
  }
  size_t num = output_shape.size();
  size_t dim_num = output_shape[0].size();
  for (size_t i = 1; i < num; i++) {
    size_t tmp_dim_num = output_shape[i].size();
    if (dim_num != tmp_dim_num) {
      GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num);
      return false;
    }
  }
  if (dim_num == 0) {
    return true;
  }
  for (size_t i = 0; i < dim_num; i++) {
    int64_t dim_value = output_shape[0][i];
    for (size_t j = 1; j < num; j++) {
      int64_t tmp_dim_value = output_shape[j][i];
      if (dim_value != tmp_dim_value) {
        GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i,
               dim_value, j, tmp_dim_value);
        return false;
      }
    }
  }
  return true;
}

///
/// @brief Create StreamSwitchN node
/// @param [in] graph
/// @param [in] name
/// @param [in] pred_value
/// @param [in] batch_shape
/// @param [in] combined_batch
/// @return ge::NodePtr
///
NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name,
                                             const OutDataAnchorPtr &pred_value,
                                             const vector<vector<int64_t>> &batch_shape,
                                             const vector<vector<int64_t>> &combined_batch) {
  OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN);
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  GELOGI("Create StreamSwitchN op:%s.", name.c_str());
  OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc();
  if (pred_desc == nullptr) {
    GELOGE(FAILED, "Get pred_desc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "AddInputDesc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  NodePtr switch_case_node = graph->AddNode(op_desc);
  if (switch_case_node == nullptr) {
    GELOGE(FAILED, "Create node failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  uint32_t batch_num = static_cast<uint32_t>(batch_shape.size());
  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
    GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_)) {
    GELOGE(FAILED, "Set attr ATTR_DYNAMIC_TYPE failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
    GELOGE(FAILED, "Set attr ATTR_USER_DESIGNEATE_SHAPE_ORDER failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  for (uint32_t i = 0; i < batch_num; i++) {
    const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) {
      GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str());
      return nullptr;
    }
    const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) {
      GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str());
      return nullptr;
    }
  }
  return switch_case_node;
}

///
/// @brief Bypass SwitchN node
/// @param [in] switch_n_node
/// @param [in] switch_case
/// @return Status
///
Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) {
  InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT);
  if (in_data_anchor == nullptr) {
    GELOGE(FAILED, "Check in_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str());
    return FAILED;
  }
  OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
  if (peer_data_anchor == nullptr) {
    GELOGE(FAILED, "Check peer_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str());
    return FAILED;
  }
  NodePtr data_input = peer_data_anchor->GetOwnerNode();
  // Remove SwitchN data input
  if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Remove SwitchN in_data_edge failed, %s->%s.", data_input->GetName().c_str(),
           switch_n_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add StreamSwitchN in_control_edge failed, %s->%s.", data_input->GetName().c_str(),
           switch_case->GetName().c_str());
    return FAILED;
  }
  // Add SwitchCase control output
  for (const OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) {
    for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
      NodePtr data_output = peer_in_anchor->GetOwnerNode();
      if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) ||
          (GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) {
        GELOGE(FAILED, "Bypass SwitchN data_edge failed, %s->%s->%s.", data_input->GetName().c_str(),
               switch_n_node->GetName().c_str(), data_output->GetName().c_str());
        return FAILED;
      }
      if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Add SwitchCase out_control_edge failed, %s->%s.", switch_case->GetName().c_str(),
               data_output->GetName().c_str());
        return FAILED;
      }
    }
  }
  GE_CHK_STATUS_RET(MoveCtrlEdges(switch_n_node, switch_case), "Move ctrl edges failed.");
  bypass_nodes_.emplace_back(switch_n_node);
  GELOGI("Bypass SwitchN node %s success.", switch_n_node->GetName().c_str());
  return SUCCESS;
}

///
/// @brief Attach stream_label & batch_label for batch branch
/// @param [in] switch_case_node
/// @return Status
///
Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) {
  std::vector<std::string> stream_label_list;
  for (uint32_t i = 0; i < static_cast<uint32_t>(batch_head_nodes_.size()); i++) {
    if (AttachBatchLabel(i) != SUCCESS) {
      GELOGE(FAILED, "AttachBatchLabel failed, batch_idx=%u", i);
      return FAILED;
    }
    const std::string &stream_label = "stream_label_batch_" + std::to_string(i);
    if (AttachStreamLabel(i, stream_label) != SUCCESS) {
      GELOGE(FAILED, "AttachStreamLabel failed, stream_label=%s", stream_label.c_str());
      return FAILED;
    }
    stream_label_list.emplace_back(stream_label);
  }
  return switch_case_node == nullptr ? SUCCESS : SetActiveLabelList(switch_case_node, stream_label_list);
}

///
/// @brief Attach batch_label for batch branch
/// @param [in] batch_idx
/// @return Status
///
Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) {
  std::stack<NodePtr> nodes;
  for (const auto &node : batch_head_nodes_[batch_idx]) {
    nodes.push(node);
  }
  const std::string &batch_label = "Batch_" + std::to_string(batch_idx);
  std::unordered_set<NodePtr> handled_nodes;
  while (!nodes.empty()) {
    NodePtr cur_node = nodes.top();
    nodes.pop();
    if (handled_nodes.count(cur_node) > 0) {
      continue;
    }
    OpDescPtr cur_desc = cur_node->GetOpDesc();
    GE_CHECK_NOTNULL(cur_desc);
    if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
      std::string tmp_label;
      if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) {
        GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL failed, node: %s.", cur_desc->GetName().c_str());
        return FAILED;
      }
      if (tmp_label != batch_label) {
        GELOGE(FAILED, "Reach other batch_branch, node:%s, cur_label:%s, batch_label:%s.", cur_desc->GetName().c_str(),
               tmp_label.c_str(), batch_label.c_str());
        return FAILED;
      }
    }
    GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str());
    if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
      GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", cur_desc->GetName().c_str());
      return FAILED;
    }
    for (const auto &out_node : cur_node->GetOutAllNodes()) {
      OpDescPtr op_desc = out_node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      const std::string &type = op_desc->GetType();
      if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
        continue;
      }
      if (type == NETOUTPUT) {
        GELOGE(FAILED, "Reach net_output without Merge, cur_node:%s.", cur_node->GetName().c_str());
        return FAILED;
      }
      nodes.push(out_node);
    }
    (void)handled_nodes.insert(cur_node);
  }
  return SUCCESS;
}

///
/// @brief Attach stream_label for batch branch
/// @param [in] batch_idx
/// @param [in] stream_label
/// @return Status
///
Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label) {
  std::stack<NodePtr> nodes;
  for (const auto &node : batch_head_nodes_[batch_idx]) {
    nodes.push(node);
  }
  std::unordered_set<NodePtr> handled_nodes;
  while (!nodes.empty()) {
    NodePtr cur_node = nodes.top();
    nodes.pop();
    OpDescPtr cur_desc = cur_node->GetOpDesc();
    GE_CHECK_NOTNULL(cur_desc);
    if ((handled_nodes.count(cur_node) > 0) || (cur_desc->HasAttr(ATTR_NAME_STREAM_LABEL))) {
      continue;
    }
    GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str());
    if (SetStreamLabel(cur_node, stream_label) != SUCCESS) {
      GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str());
      return FAILED;
    }
    for (const auto &out_node : cur_node->GetOutAllNodes()) {
      nodes.push(out_node);
    }
    (void)handled_nodes.insert(cur_node);
  }
  return SUCCESS;
}

///
/// @brief move edges from old_node to new_node
/// @param [in] old_node
/// @param [in] new_node
/// @return Status
///
Status MultiBatchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) {
  if (old_node == new_node) {
    return SUCCESS;
  }
  for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) {
    GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()),
                  "SwitchN remove in ctrl edge failed.");
    GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()),
                  "StreamSwitchN add in ctrl edge failed.");
  }
  for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) {
    GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
                  "SwitchN remove out ctrl edge failed.");
    GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
                  "StreamSwitchN add out ctrl edge failed.");
  }
  return SUCCESS;
}

///
/// @brief Attach stream_label & batch_label without changing the structure of the graph
/// @param [in] batch_num
/// @return Status
///
Status MultiBatchPass::AttachLabelOnly(uint32_t batch_num) {
  std::vector<NodePtr> output_nodes;
  for (uint32_t i = 0; i < batch_num; i++) {
    output_nodes.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      // idx is promised to be valid
      OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
      GE_CHECK_NOTNULL(out_data_anchor);
      for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        output_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
      }
    }
    batch_head_nodes_.emplace_back(output_nodes);
  }
  return AttachLabel(nullptr);
}
} // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module ME and the underlying hardware and serves as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor so as to fully exploit its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE is mainly composed of two parts: GE API and GE Core.
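
The multi_batch_pass.cc file above is one of the graph optimization passes that GE Core runs on such a graph: it collapses the per-batch SwitchN nodes of a dynamic-shape model into a single StreamSwitchN node and labels every batch branch with stream and batch labels. The snippet below is a minimal sketch of how such a pass could be driven on an already-built compute graph; RunMultiBatchPass is a hypothetical helper used only for illustration (in the real pipeline the pass is registered with GE's pass machinery rather than called by hand), and the sketch assumes Run() and ClearStatus() are publicly accessible as they appear in the file.

// Minimal usage sketch, not part of multi_batch_pass.cc: RunMultiBatchPass is a
// hypothetical helper standing in for GE's own pass scheduling.
#include "graph/passes/multi_batch_pass.h"

ge::Status RunMultiBatchPass(const ge::ComputeGraphPtr &graph) {
  ge::MultiBatchPass pass;
  // Reset the pass-internal node caches so the same object can be reused across graphs.
  if (pass.ClearStatus() != ge::SUCCESS) {
    return ge::FAILED;
  }
  // Replace the SwitchN nodes with a single StreamSwitchN node and attach
  // stream/batch labels; the pass returns SUCCESS without changes if no SwitchN exists.
  return pass.Run(graph);
}

Calling ClearStatus() first only matters when the same pass object is reused across graphs, since the pass caches the SwitchN, bypass, and batch-head nodes it finds in member vectors.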