You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

for_pass.cc 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/for_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "common/op/ge_op_utils.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "framework/common/debug/log.h"
  21. #include "framework/common/ge_inner_error_codes.h"
  22. #include "framework/common/types.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/type_utils.h"
  26. #include "graph/utils/node_utils.h"
  27. #include "graph/utils/op_desc_utils.h"
  namespace {
  // Slots of the fixed inputs on the generated While node. The order matches the
  // order in which BuildWhileInfo appends to WhileInfo::data_inputs.
  constexpr uint32_t kWhileIInputIndex = 0;
  constexpr uint32_t kWhileAbsDeltaInputIndex = 1;
  constexpr uint32_t kWhileRangeInputIndex = 2;
  constexpr uint32_t kWhileStartInputIndex = 3;
  constexpr uint32_t kWhileDeltaInputIndex = 4;
  // First While input slot that carries a data input taken from the For node.
  constexpr uint32_t kWhileDataInputIndex = 5;

  // Input slots of the subgraph-call node created inside the While body:
  // slot 0 receives the loop variable, data inputs start at slot 1.
  constexpr uint32_t kSubgraphLoopVarInputIndex = 0;
  constexpr uint32_t kSubgraphInputIndex = 1;

  // First While output slot that is wired to the consumers of a For data output.
  constexpr uint32_t kWhileOutputIndex = 5;

  // Offset subtracted from the for-body input index in UpdateForBodyInputMapping.
  constexpr size_t kIDiffValue = 2;

  // Op type string used when creating the Abs nodes in CreateLoopInput.
  const std::string kAbs("Abs");
  }  // namespace
  41. namespace ge {
  42. Status ForPass::Run(NodePtr &node) {
  43. if (node->GetType() != FOR) {
  44. GELOGD("no need for_pass for node %s.", node->GetName().c_str());
  45. return SUCCESS;
  46. }
  47. GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str());
  48. ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  49. GE_CHECK_NOTNULL(graph);
  50. ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph);
  51. GE_CHECK_NOTNULL(root_graph);
  52. ForInfo for_info;
  53. GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info),
  54. "Build ForInfo failed, node:%s.", node->GetName().c_str());
  55. WhileInfo while_info;
  56. GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info),
  57. "Transfer WhileInfo from ForInfo failed, node:%s.", node->GetName().c_str());
  58. ComputeGraphPtr cond_graph = BuildCondGraph(while_info);
  59. if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) {
  60. GELOGE(FAILED, "Add while_cond_graph failed, node:%s.", node->GetName().c_str());
  61. return FAILED;
  62. }
  63. ComputeGraphPtr body_graph = BuildBodyGraph(while_info);
  64. if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) {
  65. GELOGE(FAILED, "Add while_body_graph failed, node:%s.", node->GetName().c_str());
  66. return FAILED;
  67. }
  68. GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info),
  69. "Update InputMapping for for-body-graph failed, node:%s.", node->GetName().c_str());
  70. // for node has and only has one subgraph
  71. GE_CHECK_NOTNULL(node->GetOpDesc());
  72. node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0));
  73. GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str());
  74. return IsolateAndDeleteNode(node, std::vector<int>());
  75. }
  76. ///
  77. /// @brief Build for_info
  78. /// @param [in] root_graph
  79. /// @param [in] node
  80. /// @param [out] for_info
  81. /// @return Status
  82. ///
  83. Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) {
  84. GELOGI("Begin to build for_info for node %s.", node->GetName().c_str());
  85. OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT);
  86. OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT);
  87. OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT);
  88. if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) {
  89. GELOGE(FAILED, "BuildForInfo for %s failed: start/limit/delta is NULL.", node->GetName().c_str());
  90. return FAILED;
  91. }
  92. std::vector<OutDataAnchorPtr> data_inputs;
  93. std::vector<std::vector<InDataAnchorPtr>> data_outputs;
  94. std::vector<OutControlAnchorPtr> ctrl_inputs;
  95. std::vector<InControlAnchorPtr> ctrl_outputs;
  96. if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) {
  97. GELOGE(FAILED, "BuildForInfo for %s failed: find inputs/outputs failed.", node->GetName().c_str());
  98. return FAILED;
  99. }
  100. NodeUtils::UnlinkAll(*node);
  101. OpDescPtr op_desc = node->GetOpDesc();
  102. GE_CHECK_NOTNULL(op_desc);
  103. // For node has and only has one sub_graph
  104. std::string for_body_name = op_desc->GetSubgraphInstanceName(0);
  105. if (for_body_name.empty()) {
  106. GELOGE(FAILED, "BuildForInfo for %s failed: sub_graph_name is empty.", node->GetName().c_str());
  107. return FAILED;
  108. }
  109. ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name);
  110. if (for_body == nullptr) {
  111. GELOGE(FAILED, "BuildForInfo for %s failed: for_body_graph is NULL.", node->GetName().c_str());
  112. return FAILED;
  113. }
  114. for_info.for_node = node;
  115. for_info.start = start;
  116. for_info.limit = limit;
  117. for_info.delta = delta;
  118. for_info.body_name = for_body_name;
  119. for_info.for_body = for_body;
  120. for_info.data_inputs = std::move(data_inputs);
  121. for_info.data_outputs = std::move(data_outputs);
  122. for_info.ctrl_inputs = std::move(ctrl_inputs);
  123. for_info.ctrl_outputs = std::move(ctrl_outputs);
  124. GELOGI("Build for_info for node %s success.", node->GetName().c_str());
  125. return SUCCESS;
  126. }
  127. ///
  128. /// @brief Find input with index for For node
  129. /// @param [in] node
  130. /// @param [in] index
  131. /// @return OutDataAnchorPtr
  132. ///
  133. OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) {
  134. if (node == nullptr) {
  135. GELOGE(FAILED, "FindInputWithIndex failed: node is NULL.");
  136. return nullptr;
  137. }
  138. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  139. if (in_data_anchor == nullptr) {
  140. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
  141. return nullptr;
  142. }
  143. return in_data_anchor->GetPeerOutAnchor();
  144. }
  145. ///
  146. /// @brief Find inputs / outputs for for node
  147. /// @param [in] node
  148. /// @param [out] data_inputs
  149. /// @param [out] data_outputs
  150. /// @param [out] ctrl_inputs
  151. /// @param [out] ctrl_outputs
  152. /// @return Status
  153. ///
  154. Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
  155. std::vector<std::vector<ge::InDataAnchorPtr>> &data_outputs,
  156. std::vector<ge::OutControlAnchorPtr> &ctrl_inputs,
  157. std::vector<ge::InControlAnchorPtr> &ctrl_outputs) {
  158. GE_CHECK_NOTNULL(node);
  159. uint32_t input_data_num = node->GetAllInDataAnchorsSize();
  160. for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
  161. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  162. GE_CHECK_NOTNULL(in_data_anchor);
  163. data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
  164. }
  165. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  166. std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
  167. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  168. peer_in_data_anchors.emplace_back(peer_in_data_anchor);
  169. }
  170. data_outputs.emplace_back(peer_in_data_anchors);
  171. }
  172. InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
  173. GE_CHECK_NOTNULL(in_ctrl_anchor);
  174. for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  175. ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
  176. }
  177. OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
  178. GE_CHECK_NOTNULL(out_ctrl_anchor);
  179. for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  180. ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
  181. }
  182. return SUCCESS;
  183. }
  184. ///
  185. /// @brief Transfer while_info from for_info
  186. /// @param [in] graph
  187. /// @param [in] for_info
  188. /// @param [out] while_info
  189. /// @return Status
  190. ///
  191. Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) {
  192. std::string for_name = for_info.for_node->GetName();
  193. GELOGI("Begin to transfer for_info to while_info, node:%s.", for_name.c_str());
  194. std::string i_name = for_name + "_i";
  195. NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0));
  196. if (i_node == nullptr) {
  197. GELOGE(FAILED, "TranWhileInfo failed: create i_node failed.");
  198. return FAILED;
  199. }
  200. AddRePassNode(i_node);
  201. std::string identity_name = i_name + "_Identity";
  202. NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true));
  203. // Const node has and only has one output, Identity node has and only has one input
  204. if ((identity_node == nullptr) ||
  205. (GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) {
  206. GELOGE(FAILED, "TranWhileInfo failed: Add data-edge %s:0->%s:0 failed.", i_name.c_str(), identity_name.c_str());
  207. return FAILED;
  208. }
  209. AddRePassNode(identity_node);
  210. // Identity node has and only has one output
  211. OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0);
  212. if (i_input == nullptr) {
  213. GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL.");
  214. return FAILED;
  215. }
  216. OutDataAnchorPtr range_input = nullptr;
  217. OutDataAnchorPtr abs_delta_input = nullptr;
  218. if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) {
  219. GELOGE(FAILED, "TranWhileInfo failed: create loop input failed.");
  220. return FAILED;
  221. }
  222. BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info);
  223. if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) {
  224. GELOGE(FAILED, "TranWhileInfo failed: insert while node failed.");
  225. return FAILED;
  226. }
  227. GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.",
  228. for_name.c_str(), while_info.while_node->GetName().c_str());
  229. return SUCCESS;
  230. }
  231. ///
  232. /// @brief Create const op_desc
  233. /// @param [in] name
  234. /// @param [in] value
  235. /// @return OpDescPtr
  236. ///
  237. OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) {
  238. OpDescPtr const_op_desc = MakeShared<OpDesc>(name, CONSTANT);
  239. if (const_op_desc == nullptr) {
  240. GELOGE(FAILED, "Create op_desc failed, const:%s.", name.c_str());
  241. return nullptr;
  242. }
  243. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32);
  244. GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  245. if (const_value == nullptr) {
  246. GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str());
  247. return nullptr;
  248. }
  249. if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) {
  250. GELOGE(FAILED, "Set ATTR_NAME_WEIGHTS failed, const:%s.", name.c_str());
  251. return nullptr;
  252. }
  253. if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) {
  254. GELOGE(FAILED, "Add output desc failed, const:%s.", name.c_str());
  255. return nullptr;
  256. }
  257. return const_op_desc;
  258. }
  259. ///
  260. /// @brief Create loop node
  261. /// @param [in] graph
  262. /// @param [in] for_info
  263. /// @param [out] range_input
  264. /// @param [out] abs_delta_input
  265. /// @return Status
  266. ///
  267. Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info,
  268. OutDataAnchorPtr &range_input, OutDataAnchorPtr &abs_delta_input) {
  269. std::string for_name = for_info.for_node->GetName();
  270. GELOGD("Begin to create loop_count input, node:%s", for_name.c_str());
  271. OutDataAnchorPtr start = for_info.start;
  272. OutDataAnchorPtr limit = for_info.limit;
  273. OutDataAnchorPtr delta = for_info.delta;
  274. std::string sub_name_0 = for_name + "_Sub_0";
  275. std::string abs_name_0 = for_name + "_Abs_0";
  276. std::string abs_name_1 = for_name + "_Abs_1";
  277. // i * |delta| < |limit-start|
  278. PartialGraphBuilder graph_builder;
  279. graph_builder.SetOwnerGraph(graph)
  280. .AddExistNode(for_info.start->GetOwnerNode())
  281. .AddExistNode(for_info.limit->GetOwnerNode())
  282. .AddExistNode(for_info.delta->GetOwnerNode())
  283. .AddNode(CreateOpDesc(sub_name_0, SUB, false))
  284. .AddNode(CreateOpDesc(abs_name_0, kAbs, true))
  285. .AddNode(CreateOpDesc(abs_name_1, kAbs, true))
  286. .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0)
  287. .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0)
  288. .AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1)
  289. .AddDataLink(sub_name_0, 0, abs_name_1, 0);
  290. graphStatus error_code = GRAPH_SUCCESS;
  291. std::string error_msg;
  292. if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) {
  293. GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  294. return FAILED;
  295. }
  296. // Add repass_nodes
  297. for (auto &node : graph_builder.GetAllNodes()) {
  298. AddRePassNode(node);
  299. }
  300. NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0);
  301. NodePtr loop_count_node = graph_builder.GetNode(abs_name_1);
  302. if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) {
  303. GELOGE(FAILED, "Create loop node failed: node is NULL.");
  304. return FAILED;
  305. }
  306. GELOGD("Create loop_range input succ, node:%s", for_name.c_str());
  307. // abs_node has and only has one output
  308. abs_delta_input = abs_delta_node->GetOutDataAnchor(0);
  309. range_input = loop_count_node->GetOutDataAnchor(0);
  310. return SUCCESS;
  311. }
  312. ///
  313. /// @brief Create op_desc
  314. /// @param [in] name
  315. /// @param [in] type
  316. /// @param [in] io_equal_flag
  317. /// @return OpDescPtr
  318. ///
  319. OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) {
  320. OpDescBuilder op_desc_builder(name, type);
  321. if (io_equal_flag) {
  322. op_desc_builder.AddInput("x")
  323. .AddOutput("y");
  324. } else {
  325. op_desc_builder.AddInput("x1")
  326. .AddInput("x2")
  327. .AddOutput("y");
  328. }
  329. return op_desc_builder.Build();
  330. }
  331. ///
  332. /// @brief Build while-info
  333. /// @param [in] for_info
  334. /// @param [in] i_input
  335. /// @param [in] range_input
  336. /// @param [in] abs_delta_input
  337. /// @param [out] while_info
  338. /// @return void
  339. ///
  340. void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
  341. const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
  342. WhileInfo &while_info) {
  343. while_info.i = i_input;
  344. while_info.abs_delta = abs_delta_input;
  345. while_info.range = range_input;
  346. while_info.start = for_info.start;
  347. while_info.delta = for_info.delta;
  348. while_info.for_body_name = for_info.body_name;
  349. while_info.for_body = for_info.for_body;
  350. while_info.data_inputs.emplace_back(while_info.i);
  351. while_info.data_inputs.emplace_back(while_info.abs_delta);
  352. while_info.data_inputs.emplace_back(while_info.range);
  353. while_info.data_inputs.emplace_back(while_info.start);
  354. while_info.data_inputs.emplace_back(while_info.delta);
  355. for (auto &item : for_info.data_inputs) {
  356. while_info.data_inputs.emplace_back(item);
  357. }
  358. for (auto &item : for_info.data_outputs) {
  359. while_info.data_outputs.emplace_back(item);
  360. }
  361. for (auto &item : for_info.ctrl_inputs) {
  362. while_info.ctrl_inputs.emplace_back(item);
  363. }
  364. for (auto &item : for_info.ctrl_outputs) {
  365. while_info.ctrl_outputs.emplace_back(item);
  366. }
  367. }
  368. ///
  369. /// @brief Insert while_node
  370. /// @param [in] graph
  371. /// @param [in] name
  372. /// @param [in&out] while_info
  373. /// @return Status
  374. ///
  375. Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) {
  376. GELOGD("Begin to create while node, name:%s.", name.c_str());
  377. size_t arg_num = while_info.data_inputs.size();
  378. OpDescBuilder op_desc_builder(name, WHILE);
  379. OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build();
  380. if (op_desc == nullptr) {
  381. GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str());
  382. return FAILED;
  383. }
  384. NodePtr while_node = graph->AddNode(op_desc);
  385. if (while_node == nullptr) {
  386. GELOGE(FAILED, "Create while node failed, name:%s.", name.c_str());
  387. return FAILED;
  388. }
  389. AddRePassNode(while_node);
  390. while_info.while_node = while_node;
  391. if (BuildWhileLink(while_info) != SUCCESS) {
  392. GELOGE(FAILED, "Build while link-edge failed, name:%s.", name.c_str());
  393. return FAILED;
  394. }
  395. GELOGD("Create while node succ, name:%s.", name.c_str());
  396. return SUCCESS;
  397. }
  398. ///
  399. /// @brief Build while link-edge
  400. /// @param [in] while_info
  401. /// @return Status
  402. ///
  403. Status ForPass::BuildWhileLink(const WhileInfo &while_info) {
  404. NodePtr while_node = while_info.while_node;
  405. GE_CHECK_NOTNULL(while_node);
  406. size_t input_num = while_info.data_inputs.size();
  407. for (size_t i = 0; i < input_num; i++) {
  408. InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i);
  409. GE_CHECK_NOTNULL(in_data_anchor);
  410. OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i];
  411. if (peer_out_anchor == nullptr) {
  412. continue;
  413. }
  414. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor),
  415. "Add data-edge %s:%d->%s:%d failed.",
  416. peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
  417. while_node->GetName().c_str(), i);
  418. }
  419. size_t output_num = while_info.data_outputs.size();
  420. for (size_t i = 0; i < output_num; i++) {
  421. OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast<int>(i + kWhileOutputIndex));
  422. GE_CHECK_NOTNULL(out_data_anchor);
  423. for (auto &peer_in_anchor : while_info.data_outputs[i]) {
  424. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor),
  425. "Add data-edge %s:%d->%s:%d failed.",
  426. while_node->GetName().c_str(), i + kWhileOutputIndex,
  427. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  428. }
  429. }
  430. InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor();
  431. GE_CHECK_NOTNULL(in_ctrl_anchor);
  432. for (auto &peer_out_anchor : while_info.ctrl_inputs) {
  433. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor),
  434. "Add ctrl-edge %s->%s failed.",
  435. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  436. in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
  437. }
  438. OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor();
  439. GE_CHECK_NOTNULL(out_ctrl_anchor);
  440. for (auto &peer_in_anchor : while_info.ctrl_outputs) {
  441. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor),
  442. "Add ctrl-edge %s->%s failed.",
  443. out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
  444. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  445. }
  446. return SUCCESS;
  447. }
  448. ///
  449. /// @brief Build cond_graph for while_node
  450. /// @param [in&out] while_info
  451. /// @return ComputeGraphPtr
  452. ///
  453. ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) {
  454. std::string cond_name = while_info.for_body_name + "_Cond";
  455. CompleteGraphBuilder graph_builder(cond_name);
  456. // Add parent node
  457. graph_builder.SetParentNode(while_info.while_node);
  458. // Add Node
  459. const std::string mul_name = "Mul";
  460. graph_builder.AddNode(CreateOpDesc(mul_name, MUL, false));
  461. const std::string less_name = "Less";
  462. graph_builder.AddNode(CreateOpDesc(less_name, LESS, false));
  463. // Set Input
  464. graph_builder.SetInput(kWhileIInputIndex, { mul_name }, { 0 })
  465. .SetInput(kWhileAbsDeltaInputIndex, { mul_name }, { 1 })
  466. .SetInput(kWhileRangeInputIndex, { less_name }, { 1 })
  467. .SetUselessInput(kWhileStartInputIndex)
  468. .SetUselessInput(kWhileDeltaInputIndex);
  469. size_t input_num = while_info.data_inputs.size();
  470. for (size_t i = kWhileDataInputIndex; i < input_num; i++) {
  471. graph_builder.SetUselessInput(i);
  472. }
  473. // Add Output
  474. graph_builder.AddOutput(less_name, 0);
  475. // Add Edges
  476. graph_builder.AddDataLink(mul_name, 0, less_name, 0);
  477. // Add Input-Mapping
  478. std::map<uint32_t, uint32_t> input_mapping;
  479. for (size_t i = 0; i < input_num; i++) {
  480. input_mapping[i] = i;
  481. }
  482. graph_builder.SetInputMapping(input_mapping);
  483. graphStatus error_code = GRAPH_SUCCESS;
  484. std::string error_msg;
  485. ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg);
  486. if (cond_graph == nullptr) {
  487. GELOGE(FAILED, "Build cond_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  488. return nullptr;
  489. }
  490. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  491. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND);
  492. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name);
  493. while_info.while_cond = cond_graph;
  494. return cond_graph;
  495. }
  496. ///
  497. /// @brief Build body_graph for while_node
  498. /// @param [in&out] while_info
  499. /// @return ComputeGraphPtr
  500. ///
  501. ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) {
  502. std::string body_name = while_info.for_body_name + "_Body";
  503. CompleteGraphBuilder graph_builder(body_name);
  504. // Add parent node
  505. graph_builder.SetParentNode(while_info.while_node);
  506. // Add calculation nodes
  507. std::string const_name = "Const";
  508. std::string add_name_0 = "Add_0";
  509. std::string mul_name = "Mul";
  510. std::string add_name_1 = "Add_1";
  511. graph_builder.AddNode(CreateConstDesc(const_name, 1))
  512. .AddNode(CreateOpDesc(add_name_0, ADD, false))
  513. .AddNode(CreateOpDesc(mul_name, MUL, false))
  514. .AddNode(CreateOpDesc(add_name_1, ADD, false));
  515. // Add Subgraph node
  516. auto input_num = static_cast<uint32_t>(while_info.data_inputs.size());
  517. std::string sub_graph_node_name = while_info.for_body_name;
  518. uint32_t sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex;
  519. auto sub_graph_output_num = static_cast<uint32_t>(while_info.data_outputs.size());
  520. graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num));
  521. // Set Input
  522. graph_builder.SetInput(kWhileIInputIndex, { add_name_0, mul_name }, { 0, 0 })
  523. .SetUselessInput(kWhileAbsDeltaInputIndex)
  524. .SetUselessInput(kWhileRangeInputIndex)
  525. .SetInput(kWhileStartInputIndex, { add_name_1 }, { 0 })
  526. .SetInput(kWhileDeltaInputIndex, { mul_name }, { 1 });
  527. for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) {
  528. graph_builder.SetInput(i + kWhileDataInputIndex, { sub_graph_node_name }, { i + kSubgraphInputIndex });
  529. }
  530. // Add Outputs
  531. graph_builder.AddOutput(add_name_0, 0);
  532. for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) {
  533. graph_builder.AddOutput("Data_" + std::to_string(i), 0);
  534. }
  535. for (uint32_t i = 0; i < sub_graph_output_num; i++) {
  536. graph_builder.AddOutput(sub_graph_node_name, i);
  537. }
  538. // Add Edges
  539. graph_builder.AddDataLink(const_name, 0, add_name_0, 1)
  540. .AddDataLink(mul_name, 0, add_name_1, 1)
  541. .AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex);
  542. // Add Input-Mapping
  543. std::map<uint32_t, uint32_t> input_mapping;
  544. for (size_t i = 0; i < input_num; i++) {
  545. input_mapping[i] = i;
  546. }
  547. graph_builder.SetInputMapping(input_mapping);
  548. // Add outputMapping
  549. std::map<uint32_t, uint32_t> output_mapping;
  550. for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) {
  551. output_mapping[i] = i;
  552. }
  553. graph_builder.SetOutputMapping(output_mapping);
  554. graphStatus error_code = GRAPH_SUCCESS;
  555. std::string error_msg;
  556. ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg);
  557. if (body_graph == nullptr) {
  558. GELOGE(FAILED, "Build body_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  559. return nullptr;
  560. }
  561. NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name);
  562. if (sub_graph_node == nullptr) {
  563. GELOGE(FAILED, "Get sub_graph_node failed: name:%s.", sub_graph_node_name.c_str());
  564. return nullptr;
  565. }
  566. while_info.sub_graph_node = sub_graph_node;
  567. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  568. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY);
  569. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name);
  570. while_info.while_body = body_graph;
  571. return body_graph;
  572. }
  573. ///
  574. /// @brief Create op_desc for subgraph node
  575. /// @param [in] name
  576. /// @param [in] input_num
  577. /// @param [in] output_num
  578. /// @return OpDescPtr
  579. ///
  580. OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) {
  581. OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
  582. op_desc_builder.AddDynamicInput("args", input_num)
  583. .AddDynamicOutput("output", output_num);
  584. OpDescPtr op_desc = op_desc_builder.Build();
  585. if (op_desc == nullptr) {
  586. GELOGE(FAILED, "Create op_desc for subgraph node failed, name:%s.", name.c_str());
  587. return nullptr;
  588. }
  589. size_t index = op_desc->GetSubgraphInstanceNames().size();
  590. op_desc->AddSubgraphName("f");
  591. op_desc->SetSubgraphInstanceName(index, name);
  592. return op_desc;
  593. }
  594. ///
  595. /// @brief Update InputMapping for for-body-graph
  596. /// @param [in] while_info
  597. /// @return Status
  598. ///
  599. Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
  600. ComputeGraphPtr for_body = while_info.for_body;
  601. GE_CHECK_NOTNULL(for_body);
  602. // index_of_cur_graph_node_input -> index_of_new_graph_node_input
  603. std::map<uint32_t, uint32_t> input_mapping;
  604. size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT;
  605. for (size_t i = 0; i < input_num; i++) {
  606. if (i == FOR_START_INPUT) {
  607. input_mapping[i] = i;
  608. } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
  609. continue;
  610. } else {
  611. input_mapping[i] = i - kIDiffValue;
  612. }
  613. }
  614. for_body->UpdateInputMapping(input_mapping);
  615. for_body->SetParentNode(while_info.sub_graph_node);
  616. for_body->SetParentGraph(while_info.while_body);
  617. return SUCCESS;
  618. }
  619. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示