You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

for_pass.cc 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/for_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "common/op/ge_op_utils.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "framework/common/debug/log.h"
  21. #include "framework/common/ge_inner_error_codes.h"
  22. #include "framework/common/types.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/type_utils.h"
  26. #include "graph/utils/node_utils.h"
  27. #include "graph/utils/op_desc_utils.h"
  namespace {
  // Input slots of the generated While node (and of its cond/body subgraphs).
  // Slots 0-4 carry loop-control values created by ForPass; the For node's own
  // data inputs are appended starting at kWhileDataInputIndex.
  constexpr uint32_t kWhileIInputIndex = 0;         // loop counter i
  constexpr uint32_t kWhileAbsDeltaInputIndex = 1;  // |delta|
  constexpr uint32_t kWhileRangeInputIndex = 2;     // |limit - start|
  constexpr uint32_t kWhileStartInputIndex = 3;     // start
  constexpr uint32_t kWhileDeltaInputIndex = 4;     // delta
  constexpr uint32_t kWhileDataInputIndex = 5;      // first forwarded For data input
  // Input slots of the PartitionedCall node that wraps the original For body.
  constexpr uint32_t kSubgraphLoopVarInputIndex = 0;  // loop variable: start + i * delta
  constexpr uint32_t kSubgraphInputIndex = 1;         // first forwarded For data input
  // While output slot that replaces For output 0 (For output k maps to While output k + 5).
  constexpr uint32_t kWhileOutputIndex = 5;
  // Op type used for the absolute-value nodes.
  const std::string kAbs = "Abs";
  }  // namespace
  40. namespace ge {
  41. Status ForPass::Run(NodePtr &node) {
  42. if (node->GetType() != FOR) {
  43. GELOGD("no need for_pass for node %s.", node->GetName().c_str());
  44. return SUCCESS;
  45. }
  46. GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str());
  47. ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  48. GE_CHECK_NOTNULL(graph);
  49. ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph);
  50. GE_CHECK_NOTNULL(root_graph);
  51. ForInfo for_info;
  52. GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info), "Build ForInfo failed, node:%s.",
  53. node->GetName().c_str());
  54. WhileInfo while_info;
  55. GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info), "Transfer WhileInfo from ForInfo failed, node:%s.",
  56. node->GetName().c_str());
  57. ComputeGraphPtr cond_graph = BuildCondGraph(while_info);
  58. if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) {
  59. GELOGE(FAILED, "Add while_cond_graph failed, node:%s.", node->GetName().c_str());
  60. return FAILED;
  61. }
  62. ComputeGraphPtr body_graph = BuildBodyGraph(while_info);
  63. if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) {
  64. GELOGE(FAILED, "Add while_body_graph failed, node:%s.", node->GetName().c_str());
  65. return FAILED;
  66. }
  67. GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info), "Update InputMapping for for-body-graph failed, node:%s.",
  68. node->GetName().c_str());
  69. // for node has and only has one subgraph
  70. GE_CHECK_NOTNULL(node->GetOpDesc());
  71. node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0));
  72. GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str());
  73. return IsolateAndDeleteNode(node, std::vector<int>());
  74. }
  75. ///
  76. /// @brief Build for_info
  77. /// @param [in] root_graph
  78. /// @param [in] node
  79. /// @param [out] for_info
  80. /// @return Status
  81. ///
  82. Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) {
  83. GELOGI("Begin to build for_info for node %s.", node->GetName().c_str());
  84. OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT);
  85. OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT);
  86. OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT);
  87. if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) {
  88. GELOGE(FAILED, "BuildForInfo for %s failed: start/limit/delta is NULL.", node->GetName().c_str());
  89. return FAILED;
  90. }
  91. std::vector<OutDataAnchorPtr> data_inputs;
  92. std::vector<std::vector<InDataAnchorPtr>> data_outputs;
  93. std::vector<OutControlAnchorPtr> ctrl_inputs;
  94. std::vector<InControlAnchorPtr> ctrl_outputs;
  95. if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) {
  96. GELOGE(FAILED, "BuildForInfo for %s failed: find inputs/outputs failed.", node->GetName().c_str());
  97. return FAILED;
  98. }
  99. NodeUtils::UnlinkAll(*node);
  100. OpDescPtr op_desc = node->GetOpDesc();
  101. GE_CHECK_NOTNULL(op_desc);
  102. // For node has and only has one sub_graph
  103. std::string for_body_name = op_desc->GetSubgraphInstanceName(0);
  104. if (for_body_name.empty()) {
  105. GELOGE(FAILED, "BuildForInfo for %s failed: sub_graph_name is empty.", node->GetName().c_str());
  106. return FAILED;
  107. }
  108. ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name);
  109. if (for_body == nullptr) {
  110. GELOGE(FAILED, "BuildForInfo for %s failed: for_body_graph is NULL.", node->GetName().c_str());
  111. return FAILED;
  112. }
  113. for_info.for_node = node;
  114. for_info.start = start;
  115. for_info.limit = limit;
  116. for_info.delta = delta;
  117. for_info.body_name = for_body_name;
  118. for_info.for_body = for_body;
  119. for_info.data_inputs = std::move(data_inputs);
  120. for_info.data_outputs = std::move(data_outputs);
  121. for_info.ctrl_inputs = std::move(ctrl_inputs);
  122. for_info.ctrl_outputs = std::move(ctrl_outputs);
  123. GELOGI("Build for_info for node %s succ.", node->GetName().c_str());
  124. return SUCCESS;
  125. }
  126. ///
  127. /// @brief Find input with index for For node
  128. /// @param [in] node
  129. /// @param [in] index
  130. /// @return OutDataAnchorPtr
  131. ///
  132. OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) {
  133. if (node == nullptr) {
  134. GELOGE(FAILED, "FindInputWithIndex failed: node is NULL.");
  135. return nullptr;
  136. }
  137. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  138. if (in_data_anchor == nullptr) {
  139. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
  140. return nullptr;
  141. }
  142. OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  143. if (peer_out_anchor == nullptr) {
  144. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index);
  145. return nullptr;
  146. }
  147. return peer_out_anchor;
  148. }
  149. ///
  150. /// @brief Find inputs / outputs for for node
  151. /// @param [in] node
  152. /// @param [out] data_inputs
  153. /// @param [out] data_outputs
  154. /// @param [out] ctrl_inputs
  155. /// @param [out] ctrl_outputs
  156. /// @return Status
  157. ///
  158. Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
  159. std::vector<std::vector<ge::InDataAnchorPtr>> &data_outputs,
  160. std::vector<ge::OutControlAnchorPtr> &ctrl_inputs,
  161. std::vector<ge::InControlAnchorPtr> &ctrl_outputs) {
  162. GE_CHECK_NOTNULL(node);
  163. uint32_t input_data_num = node->GetAllInDataAnchorsSize();
  164. for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
  165. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  166. if (in_data_anchor == nullptr) {
  167. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
  168. return FAILED;
  169. }
  170. GE_IF_BOOL_EXEC(
  171. in_data_anchor->GetPeerOutAnchor() == nullptr,
  172. GELOGW("Get null input by index %d from node %s ", in_data_anchor->GetIdx(), node->GetName().c_str());
  173. continue);
  174. data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
  175. }
  176. for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  177. std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
  178. for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  179. peer_in_data_anchors.emplace_back(peer_in_data_anchor);
  180. }
  181. data_outputs.emplace_back(peer_in_data_anchors);
  182. }
  183. InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
  184. GE_CHECK_NOTNULL(in_ctrl_anchor);
  185. for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  186. ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
  187. }
  188. OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
  189. GE_CHECK_NOTNULL(out_ctrl_anchor);
  190. for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  191. ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
  192. }
  193. return SUCCESS;
  194. }
  195. ///
  196. /// @brief Transfer while_info from for_info
  197. /// @param [in] graph
  198. /// @param [in] for_info
  199. /// @param [out] while_info
  200. /// @return Status
  201. ///
  202. Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) {
  203. std::string for_name = for_info.for_node->GetName();
  204. GELOGI("Begin to transfer for_info to while_info, node:%s.", for_name.c_str());
  205. std::string i_name = for_name + "_i";
  206. NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0));
  207. if (i_node == nullptr) {
  208. GELOGE(FAILED, "TranWhileInfo failed: create i_node failed.");
  209. return FAILED;
  210. }
  211. AddRePassNode(i_node);
  212. std::string identity_name = i_name + "_Identity";
  213. NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true));
  214. // Const node has and only has one output, Identity node has and only has one input
  215. if ((identity_node == nullptr) ||
  216. (GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) {
  217. GELOGE(FAILED, "TranWhileInfo failed: Add data-edge %s:0->%s:0 failed.", i_name.c_str(), identity_name.c_str());
  218. return FAILED;
  219. }
  220. AddRePassNode(identity_node);
  221. // Identity node has and only has one output
  222. OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0);
  223. if (i_input == nullptr) {
  224. GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL.");
  225. return FAILED;
  226. }
  227. OutDataAnchorPtr range_input = nullptr;
  228. OutDataAnchorPtr abs_delta_input = nullptr;
  229. if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) {
  230. GELOGE(FAILED, "TranWhileInfo failed: create loop input failed.");
  231. return FAILED;
  232. }
  233. BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info);
  234. if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) {
  235. GELOGE(FAILED, "TranWhileInfo failed: insert while node failed.");
  236. return FAILED;
  237. }
  238. GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.", for_name.c_str(),
  239. while_info.while_node->GetName().c_str());
  240. return SUCCESS;
  241. }
  242. ///
  243. /// @brief Create const op_desc
  244. /// @param [in] name
  245. /// @param [in] value
  246. /// @return OpDescPtr
  247. ///
  248. OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) {
  249. OpDescPtr const_op_desc = MakeShared<OpDesc>(name, CONSTANT);
  250. if (const_op_desc == nullptr) {
  251. GELOGE(FAILED, "Create op_desc failed, const:%s.", name.c_str());
  252. return nullptr;
  253. }
  254. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32);
  255. GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  256. if (const_value == nullptr) {
  257. GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str());
  258. return nullptr;
  259. }
  260. if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) {
  261. GELOGE(FAILED, "Set ATTR_NAME_WEIGHTS failed, const:%s.", name.c_str());
  262. return nullptr;
  263. }
  264. if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) {
  265. GELOGE(FAILED, "Add output desc failed, const:%s.", name.c_str());
  266. return nullptr;
  267. }
  268. return const_op_desc;
  269. }
  270. ///
  271. /// @brief Create loop node
  272. /// @param [in] graph
  273. /// @param [in] for_info
  274. /// @param [out] range_input
  275. /// @param [out] abs_delta_input
  276. /// @return Status
  277. ///
  278. Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info, OutDataAnchorPtr &range_input,
  279. OutDataAnchorPtr &abs_delta_input) {
  280. std::string for_name = for_info.for_node->GetName();
  281. GELOGD("Begin to create loop_count input, node:%s", for_name.c_str());
  282. OutDataAnchorPtr start = for_info.start;
  283. OutDataAnchorPtr limit = for_info.limit;
  284. OutDataAnchorPtr delta = for_info.delta;
  285. std::string sub_name_0 = for_name + "_Sub_0";
  286. std::string abs_name_0 = for_name + "_Abs_0";
  287. std::string abs_name_1 = for_name + "_Abs_1";
  288. // i * |delta| < |limit-start|
  289. PartialGraphBuilder graph_builder;
  290. graph_builder.SetOwnerGraph(graph)
  291. .AddExistNode(for_info.start->GetOwnerNode())
  292. .AddExistNode(for_info.limit->GetOwnerNode())
  293. .AddExistNode(for_info.delta->GetOwnerNode())
  294. .AddNode(CreateOpDesc(sub_name_0, SUB, false))
  295. .AddNode(CreateOpDesc(abs_name_0, kAbs, true))
  296. .AddNode(CreateOpDesc(abs_name_1, kAbs, true))
  297. .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0)
  298. .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0)
  299. .AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1)
  300. .AddDataLink(sub_name_0, 0, abs_name_1, 0);
  301. graphStatus error_code = GRAPH_SUCCESS;
  302. std::string error_msg;
  303. if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) {
  304. GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  305. return FAILED;
  306. }
  307. // Add repass_nodes
  308. for (auto &node : graph_builder.GetAllNodes()) {
  309. AddRePassNode(node);
  310. }
  311. NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0);
  312. NodePtr loop_count_node = graph_builder.GetNode(abs_name_1);
  313. if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) {
  314. GELOGE(FAILED, "Create loop node failed: node is NULL.");
  315. return FAILED;
  316. }
  317. GELOGD("Create loop_range input succ, node:%s", for_name.c_str());
  318. // abs_node has and only has one output
  319. abs_delta_input = abs_delta_node->GetOutDataAnchor(0);
  320. range_input = loop_count_node->GetOutDataAnchor(0);
  321. return SUCCESS;
  322. }
  323. ///
  324. /// @brief Create op_desc
  325. /// @param [in] name
  326. /// @param [in] type
  327. /// @param [in] io_equal_flag
  328. /// @return OpDescPtr
  329. ///
  330. OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) {
  331. OpDescBuilder op_desc_builder(name, type);
  332. if (io_equal_flag) {
  333. op_desc_builder.AddInput("x").AddOutput("y");
  334. } else {
  335. op_desc_builder.AddInput("x1").AddInput("x2").AddOutput("y");
  336. }
  337. return op_desc_builder.Build();
  338. }
  339. ///
  340. /// @brief Build while-info
  341. /// @param [in] for_info
  342. /// @param [in] i_input
  343. /// @param [in] range_input
  344. /// @param [in] abs_delta_input
  345. /// @param [out] while_info
  346. /// @return void
  347. ///
  348. void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
  349. const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
  350. WhileInfo &while_info) {
  351. while_info.i = i_input;
  352. while_info.abs_delta = abs_delta_input;
  353. while_info.range = range_input;
  354. while_info.start = for_info.start;
  355. while_info.delta = for_info.delta;
  356. while_info.for_body_name = for_info.body_name;
  357. while_info.for_body = for_info.for_body;
  358. while_info.data_inputs.emplace_back(while_info.i);
  359. while_info.data_inputs.emplace_back(while_info.abs_delta);
  360. while_info.data_inputs.emplace_back(while_info.range);
  361. while_info.data_inputs.emplace_back(while_info.start);
  362. while_info.data_inputs.emplace_back(while_info.delta);
  363. for (auto &item : for_info.data_inputs) {
  364. while_info.data_inputs.emplace_back(item);
  365. }
  366. for (auto &item : for_info.data_outputs) {
  367. while_info.data_outputs.emplace_back(item);
  368. }
  369. for (auto &item : for_info.ctrl_inputs) {
  370. while_info.ctrl_inputs.emplace_back(item);
  371. }
  372. for (auto &item : for_info.ctrl_outputs) {
  373. while_info.ctrl_outputs.emplace_back(item);
  374. }
  375. }
  376. ///
  377. /// @brief Insert while_node
  378. /// @param [in] graph
  379. /// @param [in] name
  380. /// @param [in&out] while_info
  381. /// @return Status
  382. ///
  383. Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) {
  384. GELOGD("Begin to create while node, name:%s.", name.c_str());
  385. size_t arg_num = while_info.data_inputs.size();
  386. OpDescBuilder op_desc_builder(name, WHILE);
  387. OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build();
  388. if (op_desc == nullptr) {
  389. GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str());
  390. return FAILED;
  391. }
  392. NodePtr while_node = graph->AddNode(op_desc);
  393. if (while_node == nullptr) {
  394. GELOGE(FAILED, "Create while node failed, name:%s.", name.c_str());
  395. return FAILED;
  396. }
  397. AddRePassNode(while_node);
  398. while_info.while_node = while_node;
  399. if (BuildWhileLink(while_info) != SUCCESS) {
  400. GELOGE(FAILED, "Build while link-edge failed, name:%s.", name.c_str());
  401. return FAILED;
  402. }
  403. GELOGD("Create while node succ, name:%s.", name.c_str());
  404. return SUCCESS;
  405. }
  406. ///
  407. /// @brief Build while link-edge
  408. /// @param [in] while_info
  409. /// @return Status
  410. ///
  411. Status ForPass::BuildWhileLink(const WhileInfo &while_info) {
  412. NodePtr while_node = while_info.while_node;
  413. GE_CHECK_NOTNULL(while_node);
  414. size_t input_num = while_info.data_inputs.size();
  415. for (size_t i = 0; i < input_num; i++) {
  416. InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i);
  417. GE_CHECK_NOTNULL(in_data_anchor);
  418. OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i];
  419. if (peer_out_anchor == nullptr) {
  420. continue;
  421. }
  422. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor), "Add data-edge %s:%d->%s:%d failed.",
  423. peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
  424. while_node->GetName().c_str(), i);
  425. }
  426. size_t output_num = while_info.data_outputs.size();
  427. for (size_t i = 0; i < output_num; i++) {
  428. OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast<int>(i + kWhileOutputIndex));
  429. GE_CHECK_NOTNULL(out_data_anchor);
  430. for (auto &peer_in_anchor : while_info.data_outputs[i]) {
  431. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor),
  432. "Add data-edge %s:%d->%s:%d failed.", while_node->GetName().c_str(),
  433. i + kWhileOutputIndex, peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  434. peer_in_anchor->GetIdx());
  435. }
  436. }
  437. InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor();
  438. GE_CHECK_NOTNULL(in_ctrl_anchor);
  439. for (auto &peer_out_anchor : while_info.ctrl_inputs) {
  440. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor), "Add ctrl-edge %s->%s failed.",
  441. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  442. in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
  443. }
  444. OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor();
  445. GE_CHECK_NOTNULL(out_ctrl_anchor);
  446. for (auto &peer_in_anchor : while_info.ctrl_outputs) {
  447. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor), "Add ctrl-edge %s->%s failed.",
  448. out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
  449. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  450. }
  451. return SUCCESS;
  452. }
  453. ///
  454. /// @brief Build cond_graph for while_node
  455. /// @param [in&out] while_info
  456. /// @return ComputeGraphPtr
  457. ///
  458. ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) {
  459. std::string cond_name = while_info.for_body_name + "_Cond";
  460. CompleteGraphBuilder graph_builder(cond_name);
  461. // Add parent node
  462. graph_builder.SetParentNode(while_info.while_node);
  463. // Add Node
  464. const std::string mul_name = "Mul";
  465. graph_builder.AddNode(CreateOpDesc(mul_name, MUL, false));
  466. const std::string less_name = "Less";
  467. graph_builder.AddNode(CreateOpDesc(less_name, LESS, false));
  468. // Set Input
  469. graph_builder.SetInput(kWhileIInputIndex, {mul_name}, {0})
  470. .SetInput(kWhileAbsDeltaInputIndex, {mul_name}, {1})
  471. .SetInput(kWhileRangeInputIndex, {less_name}, {1})
  472. .SetUselessInput(kWhileStartInputIndex)
  473. .SetUselessInput(kWhileDeltaInputIndex);
  474. size_t input_num = while_info.data_inputs.size();
  475. for (size_t i = kWhileDataInputIndex; i < input_num; i++) {
  476. graph_builder.SetUselessInput(i);
  477. }
  478. // Add Output
  479. graph_builder.AddOutput(less_name, 0);
  480. // Add Edges
  481. graph_builder.AddDataLink(mul_name, 0, less_name, 0);
  482. // Add Input-Mapping
  483. std::map<uint32_t, uint32_t> input_mapping;
  484. for (size_t i = 0; i < input_num; i++) {
  485. input_mapping[i] = i;
  486. }
  487. graph_builder.SetInputMapping(input_mapping);
  488. graphStatus error_code = GRAPH_SUCCESS;
  489. std::string error_msg;
  490. ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg);
  491. if (cond_graph == nullptr) {
  492. GELOGE(FAILED, "Build cond_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  493. return nullptr;
  494. }
  495. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  496. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND);
  497. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name);
  498. while_info.while_cond = cond_graph;
  499. return cond_graph;
  500. }
  501. ///
  502. /// @brief Build body_graph for while_node
  503. /// @param [in&out] while_info
  504. /// @return ComputeGraphPtr
  505. ///
  506. ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) {
  507. std::string body_name = while_info.for_body_name + "_Body";
  508. CompleteGraphBuilder graph_builder(body_name);
  509. // Add parent node
  510. graph_builder.SetParentNode(while_info.while_node);
  511. // Add calculation nodes
  512. std::string const_name = "Const";
  513. std::string add_name_0 = "Add_0";
  514. std::string mul_name = "Mul";
  515. std::string add_name_1 = "Add_1";
  516. graph_builder.AddNode(CreateConstDesc(const_name, 1))
  517. .AddNode(CreateOpDesc(add_name_0, ADD, false))
  518. .AddNode(CreateOpDesc(mul_name, MUL, false))
  519. .AddNode(CreateOpDesc(add_name_1, ADD, false));
  520. // Add Subgraph node
  521. auto input_num = static_cast<uint32_t>(while_info.data_inputs.size());
  522. std::string sub_graph_node_name = while_info.for_body_name;
  523. uint32_t sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex;
  524. auto sub_graph_output_num = static_cast<uint32_t>(while_info.data_outputs.size());
  525. graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num));
  526. // Set Input
  527. graph_builder.SetInput(kWhileIInputIndex, {add_name_0, mul_name}, {0, 0})
  528. .SetUselessInput(kWhileAbsDeltaInputIndex)
  529. .SetUselessInput(kWhileRangeInputIndex)
  530. .SetInput(kWhileStartInputIndex, {add_name_1}, {0})
  531. .SetInput(kWhileDeltaInputIndex, {mul_name}, {1});
  532. for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) {
  533. graph_builder.SetInput(i + kWhileDataInputIndex, {sub_graph_node_name}, {i + kSubgraphInputIndex});
  534. }
  535. // Add Outputs
  536. graph_builder.AddOutput(add_name_0, 0);
  537. for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) {
  538. graph_builder.AddOutput("Data_" + std::to_string(i), 0);
  539. }
  540. for (uint32_t i = 0; i < sub_graph_output_num; i++) {
  541. graph_builder.AddOutput(sub_graph_node_name, i);
  542. }
  543. // Add Edges
  544. graph_builder.AddDataLink(const_name, 0, add_name_0, 1)
  545. .AddDataLink(mul_name, 0, add_name_1, 1)
  546. .AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex);
  547. // Add Input-Mapping
  548. std::map<uint32_t, uint32_t> input_mapping;
  549. for (size_t i = 0; i < input_num; i++) {
  550. input_mapping[i] = i;
  551. }
  552. graph_builder.SetInputMapping(input_mapping);
  553. // Add outputMapping
  554. std::map<uint32_t, uint32_t> output_mapping;
  555. for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) {
  556. output_mapping[i] = i;
  557. }
  558. graph_builder.SetOutputMapping(output_mapping);
  559. graphStatus error_code = GRAPH_SUCCESS;
  560. std::string error_msg;
  561. ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg);
  562. if (body_graph == nullptr) {
  563. GELOGE(FAILED, "Build body_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  564. return nullptr;
  565. }
  566. NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name);
  567. if (sub_graph_node == nullptr) {
  568. GELOGE(FAILED, "Get sub_graph_node failed: name:%s.", sub_graph_node_name.c_str());
  569. return nullptr;
  570. }
  571. while_info.sub_graph_node = sub_graph_node;
  572. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  573. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY);
  574. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name);
  575. while_info.while_body = body_graph;
  576. return body_graph;
  577. }
  578. ///
  579. /// @brief Create op_desc for subgraph node
  580. /// @param [in] name
  581. /// @param [in] input_num
  582. /// @param [in] output_num
  583. /// @return OpDescPtr
  584. ///
  585. OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) {
  586. OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
  587. op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num);
  588. OpDescPtr op_desc = op_desc_builder.Build();
  589. if (op_desc == nullptr) {
  590. GELOGE(FAILED, "Create op_desc for subgraph node failed, name:%s.", name.c_str());
  591. return nullptr;
  592. }
  593. size_t index = op_desc->GetSubgraphInstanceNames().size();
  594. op_desc->AddSubgraphName("f");
  595. op_desc->SetSubgraphInstanceName(index, name);
  596. return op_desc;
  597. }
  598. ///
  599. /// @brief Update InputMapping for for-body-graph
  600. /// @param [in] while_info
  601. /// @return Status
  602. ///
  603. Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
  604. ComputeGraphPtr for_body = while_info.for_body;
  605. GE_CHECK_NOTNULL(for_body);
  606. // index_of_cur_graph_node_input -> index_of_new_graph_node_input
  607. std::map<uint32_t, uint32_t> input_mapping;
  608. size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT;
  609. for (size_t i = 0; i < input_num; i++) {
  610. if (i == FOR_START_INPUT) {
  611. input_mapping[i] = i;
  612. } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
  613. continue;
  614. } else {
  615. input_mapping[i] = i - 2;
  616. }
  617. }
  618. for_body->UpdateInputMapping(input_mapping);
  619. for_body->SetParentNode(while_info.sub_graph_node);
  620. for_body->SetParentGraph(while_info.while_body);
  621. return SUCCESS;
  622. }
  623. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起承上启下的作用。图引擎模块以ME下发的图作为输入，经过一系列深度图优化操作后，输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此无感知。GE主要由GE API和GE Core两部分组成，详细架构如下图所示。