You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

for_pass.cc 29 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/for_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "common/op/ge_op_utils.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "framework/common/debug/log.h"
  21. #include "framework/common/ge_inner_error_codes.h"
  22. #include "framework/common/types.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/type_utils.h"
  26. #include "graph/utils/node_utils.h"
  27. #include "graph/utils/op_desc_utils.h"
  namespace {
  // Input slots of the generated While node, in the order BuildWhileInfo fills them.
  constexpr uint32_t kWhileIInputIndex = 0;         // loop counter i
  constexpr uint32_t kWhileAbsDeltaInputIndex = 1;  // |delta|
  constexpr uint32_t kWhileRangeInputIndex = 2;     // |limit - start|
  constexpr uint32_t kWhileStartInputIndex = 3;     // start
  constexpr uint32_t kWhileDeltaInputIndex = 4;     // delta
  constexpr uint32_t kWhileDataInputIndex = 5;      // first forwarded data input of the For node

  // Input slots of the PartitionedCall node that wraps the original for-body.
  constexpr uint32_t kSubgraphLoopVarInputIndex = 0;  // loop variable fed from Add_1
  constexpr uint32_t kSubgraphInputIndex = 1;         // first forwarded data input

  // First While output that carries a data output of the original For node.
  constexpr uint32_t kWhileOutputIndex = 5;

  // Offset subtracted from for-body input indices that follow limit/delta when remapping.
  constexpr size_t kIDiffValue = 2;

  // Op type used for the two absolute-value nodes created by CreateLoopInput.
  const std::string kAbs = "Abs";
  }  // namespace
  41. namespace ge {
  42. Status ForPass::Run(NodePtr &node) {
  43. if (node->GetType() != FOR) {
  44. GELOGD("no need for_pass for node %s.", node->GetName().c_str());
  45. return SUCCESS;
  46. }
  47. GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str());
  48. ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  49. GE_CHECK_NOTNULL(graph);
  50. ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph);
  51. GE_CHECK_NOTNULL(root_graph);
  52. ForInfo for_info;
  53. GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info),
  54. "Build ForInfo failed, node:%s.", node->GetName().c_str());
  55. WhileInfo while_info;
  56. GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info),
  57. "Transfer WhileInfo from ForInfo failed, node:%s.", node->GetName().c_str());
  58. ComputeGraphPtr cond_graph = BuildCondGraph(while_info);
  59. if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) {
  60. REPORT_CALL_ERROR("E19999", "Build cond graph failed or add cond subgraph to root_graph:%s failed",
  61. root_graph->GetName().c_str());
  62. GELOGE(FAILED, "Add while_cond_graph failed, node:%s.", node->GetName().c_str());
  63. return FAILED;
  64. }
  65. ComputeGraphPtr body_graph = BuildBodyGraph(while_info);
  66. if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) {
  67. REPORT_CALL_ERROR("E19999", "Build body graph failed or add body subgraph to root_graph:%s failed",
  68. root_graph->GetName().c_str());
  69. GELOGE(FAILED, "Add while_body_graph failed, node:%s.", node->GetName().c_str());
  70. return FAILED;
  71. }
  72. GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info),
  73. "Update InputMapping for for-body-graph failed, node:%s.", node->GetName().c_str());
  74. // for node has and only has one subgraph
  75. GE_CHECK_NOTNULL(node->GetOpDesc());
  76. node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0));
  77. GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str());
  78. return IsolateAndDeleteNode(node, std::vector<int>());
  79. }
  80. ///
  81. /// @brief Build for_info
  82. /// @param [in] root_graph
  83. /// @param [in] node
  84. /// @param [out] for_info
  85. /// @return Status
  86. ///
  87. Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) {
  88. GELOGI("Begin to build for_info for node %s.", node->GetName().c_str());
  89. OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT);
  90. OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT);
  91. OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT);
  92. if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) {
  93. REPORT_INNER_ERROR("E19999", "FOR_START_INPUT index:%d or FOR_LIMIT_INPUT index:%d or FOR_DELTA_INPUT index:%d "
  94. "in data anchor of op:%s(%s) lack, check invalid",
  95. FOR_START_INPUT, FOR_LIMIT_INPUT, FOR_DELTA_INPUT,
  96. node->GetName().c_str(), node->GetType().c_str());
  97. GELOGE(FAILED, "BuildForInfo for %s failed: start/limit/delta is NULL.", node->GetName().c_str());
  98. return FAILED;
  99. }
  100. std::vector<OutDataAnchorPtr> data_inputs;
  101. std::vector<std::vector<InDataAnchorPtr>> data_outputs;
  102. std::vector<OutControlAnchorPtr> ctrl_inputs;
  103. std::vector<InControlAnchorPtr> ctrl_outputs;
  104. if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) {
  105. GELOGE(FAILED, "BuildForInfo for %s failed: find inputs/outputs failed.", node->GetName().c_str());
  106. return FAILED;
  107. }
  108. NodeUtils::UnlinkAll(*node);
  109. OpDescPtr op_desc = node->GetOpDesc();
  110. GE_CHECK_NOTNULL(op_desc);
  111. // For node has and only has one sub_graph
  112. std::string for_body_name = op_desc->GetSubgraphInstanceName(0);
  113. if (for_body_name.empty()) {
  114. REPORT_INNER_ERROR("E19999", "Get subgraph name from op:%s(%s) by index 0 failed",
  115. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  116. GELOGE(FAILED, "BuildForInfo for %s failed: sub_graph_name is empty.", node->GetName().c_str());
  117. return FAILED;
  118. }
  119. ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name);
  120. if (for_body == nullptr) {
  121. REPORT_INNER_ERROR("E19999", "Get subgraph from graph:%s by name:%s failed",
  122. root_graph->GetName().c_str(), for_body_name.c_str());
  123. GELOGE(FAILED, "BuildForInfo for %s failed: for_body_graph is NULL.", node->GetName().c_str());
  124. return FAILED;
  125. }
  126. for_info.for_node = node;
  127. for_info.start = start;
  128. for_info.limit = limit;
  129. for_info.delta = delta;
  130. for_info.body_name = for_body_name;
  131. for_info.for_body = for_body;
  132. for_info.data_inputs = std::move(data_inputs);
  133. for_info.data_outputs = std::move(data_outputs);
  134. for_info.ctrl_inputs = std::move(ctrl_inputs);
  135. for_info.ctrl_outputs = std::move(ctrl_outputs);
  136. GELOGI("Build for_info for node %s success.", node->GetName().c_str());
  137. return SUCCESS;
  138. }
  139. ///
  140. /// @brief Find input with index for For node
  141. /// @param [in] node
  142. /// @param [in] index
  143. /// @return OutDataAnchorPtr
  144. ///
  145. OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) {
  146. if (node == nullptr) {
  147. GELOGE(FAILED, "FindInputWithIndex failed: node is NULL.");
  148. return nullptr;
  149. }
  150. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  151. if (in_data_anchor == nullptr) {
  152. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
  153. return nullptr;
  154. }
  155. return in_data_anchor->GetPeerOutAnchor();
  156. }
  157. ///
  158. /// @brief Find inputs / outputs for for node
  159. /// @param [in] node
  160. /// @param [out] data_inputs
  161. /// @param [out] data_outputs
  162. /// @param [out] ctrl_inputs
  163. /// @param [out] ctrl_outputs
  164. /// @return Status
  165. ///
  166. Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
  167. std::vector<std::vector<ge::InDataAnchorPtr>> &data_outputs,
  168. std::vector<ge::OutControlAnchorPtr> &ctrl_inputs,
  169. std::vector<ge::InControlAnchorPtr> &ctrl_outputs) {
  170. GE_CHECK_NOTNULL(node);
  171. uint32_t input_data_num = node->GetAllInDataAnchorsSize();
  172. for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
  173. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  174. GE_CHECK_NOTNULL(in_data_anchor);
  175. data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
  176. }
  177. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  178. std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
  179. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  180. peer_in_data_anchors.emplace_back(peer_in_data_anchor);
  181. }
  182. data_outputs.emplace_back(peer_in_data_anchors);
  183. }
  184. InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
  185. GE_CHECK_NOTNULL(in_ctrl_anchor);
  186. for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  187. ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
  188. }
  189. OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
  190. GE_CHECK_NOTNULL(out_ctrl_anchor);
  191. for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  192. ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
  193. }
  194. return SUCCESS;
  195. }
  196. ///
  197. /// @brief Transfer while_info from for_info
  198. /// @param [in] graph
  199. /// @param [in] for_info
  200. /// @param [out] while_info
  201. /// @return Status
  202. ///
  203. Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) {
  204. std::string for_name = for_info.for_node->GetName();
  205. GELOGI("Begin to transfer for_info to while_info, node:%s.", for_name.c_str());
  206. std::string i_name = for_name + "_i";
  207. NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0));
  208. if (i_node == nullptr) {
  209. REPORT_CALL_ERROR("E19999", "Add node:%s(Const) to graph:%s failed",
  210. i_name.c_str(), graph->GetName().c_str());
  211. GELOGE(FAILED, "TranWhileInfo failed: create i_node failed.");
  212. return FAILED;
  213. }
  214. AddRePassNode(i_node);
  215. std::string identity_name = i_name + "_Identity";
  216. NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true));
  217. // Const node has and only has one output, Identity node has and only has one input
  218. if ((identity_node == nullptr) ||
  219. (GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) {
  220. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
  221. i_node->GetName().c_str(), i_node->GetType().c_str(),
  222. identity_node->GetName().c_str(), identity_node->GetType().c_str());
  223. GELOGE(FAILED, "TranWhileInfo failed: Add data-edge %s:0->%s:0 failed.", i_name.c_str(), identity_name.c_str());
  224. return FAILED;
  225. }
  226. AddRePassNode(identity_node);
  227. // Identity node has and only has one output
  228. OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0);
  229. if (i_input == nullptr) {
  230. REPORT_INNER_ERROR("E19999", "Out data anchor index:0 in op:%s(%s) is nullptr, check invalid",
  231. identity_node->GetName().c_str(), identity_node->GetType().c_str());
  232. GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL.");
  233. return FAILED;
  234. }
  235. OutDataAnchorPtr range_input = nullptr;
  236. OutDataAnchorPtr abs_delta_input = nullptr;
  237. if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) {
  238. GELOGE(FAILED, "TranWhileInfo failed: create loop input failed.");
  239. return FAILED;
  240. }
  241. BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info);
  242. if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) {
  243. GELOGE(FAILED, "TranWhileInfo failed: insert while node failed.");
  244. return FAILED;
  245. }
  246. GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.",
  247. for_name.c_str(), while_info.while_node->GetName().c_str());
  248. return SUCCESS;
  249. }
  250. ///
  251. /// @brief Create const op_desc
  252. /// @param [in] name
  253. /// @param [in] value
  254. /// @return OpDescPtr
  255. ///
  256. OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) {
  257. OpDescPtr const_op_desc = MakeShared<OpDesc>(name, CONSTANT);
  258. if (const_op_desc == nullptr) {
  259. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  260. GELOGE(FAILED, "Create op_desc failed, const:%s.", name.c_str());
  261. return nullptr;
  262. }
  263. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32);
  264. GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  265. if (const_value == nullptr) {
  266. REPORT_CALL_ERROR("E19999", "New GeTensor failed");
  267. GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str());
  268. return nullptr;
  269. }
  270. if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) {
  271. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  272. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  273. GELOGE(FAILED, "Set ATTR_NAME_WEIGHTS failed, const:%s.", name.c_str());
  274. return nullptr;
  275. }
  276. if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) {
  277. REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed, name:y",
  278. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  279. GELOGE(FAILED, "Add output desc failed, const:%s.", name.c_str());
  280. return nullptr;
  281. }
  282. return const_op_desc;
  283. }
  284. ///
  285. /// @brief Create loop node
  286. /// @param [in] graph
  287. /// @param [in] for_info
  288. /// @param [out] range_input
  289. /// @param [out] abs_delta_input
  290. /// @return Status
  291. ///
  292. Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info,
  293. OutDataAnchorPtr &range_input, OutDataAnchorPtr &abs_delta_input) {
  294. std::string for_name = for_info.for_node->GetName();
  295. GELOGD("Begin to create loop_count input, node:%s", for_name.c_str());
  296. OutDataAnchorPtr start = for_info.start;
  297. OutDataAnchorPtr limit = for_info.limit;
  298. OutDataAnchorPtr delta = for_info.delta;
  299. std::string sub_name_0 = for_name + "_Sub_0";
  300. std::string abs_name_0 = for_name + "_Abs_0";
  301. std::string abs_name_1 = for_name + "_Abs_1";
  302. // i * |delta| < |limit-start|
  303. PartialGraphBuilder graph_builder;
  304. graph_builder.SetOwnerGraph(graph)
  305. .AddExistNode(for_info.start->GetOwnerNode())
  306. .AddExistNode(for_info.limit->GetOwnerNode())
  307. .AddExistNode(for_info.delta->GetOwnerNode())
  308. .AddNode(CreateOpDesc(sub_name_0, SUB, false))
  309. .AddNode(CreateOpDesc(abs_name_0, kAbs, true))
  310. .AddNode(CreateOpDesc(abs_name_1, kAbs, true))
  311. .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0)
  312. .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0)
  313. .AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1)
  314. .AddDataLink(sub_name_0, 0, abs_name_1, 0);
  315. graphStatus error_code = GRAPH_SUCCESS;
  316. std::string error_msg;
  317. if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) {
  318. REPORT_CALL_ERROR("E19999", "Add loop input node to graph:%s failed", graph->GetName().c_str());
  319. GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  320. return FAILED;
  321. }
  322. // Add repass_nodes
  323. for (auto &node : graph_builder.GetAllNodes()) {
  324. AddRePassNode(node);
  325. }
  326. NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0);
  327. NodePtr loop_count_node = graph_builder.GetNode(abs_name_1);
  328. if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) {
  329. REPORT_CALL_ERROR("E19999", "Add loop input node to graph:%s failed", graph->GetName().c_str());
  330. GELOGE(FAILED, "Create loop node failed: node is NULL.");
  331. return FAILED;
  332. }
  333. GELOGD("Create loop_range input succ, node:%s", for_name.c_str());
  334. // abs_node has and only has one output
  335. abs_delta_input = abs_delta_node->GetOutDataAnchor(0);
  336. range_input = loop_count_node->GetOutDataAnchor(0);
  337. return SUCCESS;
  338. }
  339. ///
  340. /// @brief Create op_desc
  341. /// @param [in] name
  342. /// @param [in] type
  343. /// @param [in] io_equal_flag
  344. /// @return OpDescPtr
  345. ///
  346. OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) {
  347. OpDescBuilder op_desc_builder(name, type);
  348. if (io_equal_flag) {
  349. op_desc_builder.AddInput("x")
  350. .AddOutput("y");
  351. } else {
  352. op_desc_builder.AddInput("x1")
  353. .AddInput("x2")
  354. .AddOutput("y");
  355. }
  356. return op_desc_builder.Build();
  357. }
  358. ///
  359. /// @brief Build while-info
  360. /// @param [in] for_info
  361. /// @param [in] i_input
  362. /// @param [in] range_input
  363. /// @param [in] abs_delta_input
  364. /// @param [out] while_info
  365. /// @return void
  366. ///
  367. void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
  368. const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
  369. WhileInfo &while_info) {
  370. while_info.i = i_input;
  371. while_info.abs_delta = abs_delta_input;
  372. while_info.range = range_input;
  373. while_info.start = for_info.start;
  374. while_info.delta = for_info.delta;
  375. while_info.for_body_name = for_info.body_name;
  376. while_info.for_body = for_info.for_body;
  377. while_info.data_inputs.emplace_back(while_info.i);
  378. while_info.data_inputs.emplace_back(while_info.abs_delta);
  379. while_info.data_inputs.emplace_back(while_info.range);
  380. while_info.data_inputs.emplace_back(while_info.start);
  381. while_info.data_inputs.emplace_back(while_info.delta);
  382. for (auto &item : for_info.data_inputs) {
  383. while_info.data_inputs.emplace_back(item);
  384. }
  385. for (auto &item : for_info.data_outputs) {
  386. while_info.data_outputs.emplace_back(item);
  387. }
  388. for (auto &item : for_info.ctrl_inputs) {
  389. while_info.ctrl_inputs.emplace_back(item);
  390. }
  391. for (auto &item : for_info.ctrl_outputs) {
  392. while_info.ctrl_outputs.emplace_back(item);
  393. }
  394. }
  395. ///
  396. /// @brief Insert while_node
  397. /// @param [in] graph
  398. /// @param [in] name
  399. /// @param [in&out] while_info
  400. /// @return Status
  401. ///
  402. Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) {
  403. GELOGD("Begin to create while node, name:%s.", name.c_str());
  404. size_t arg_num = while_info.data_inputs.size();
  405. OpDescBuilder op_desc_builder(name, WHILE);
  406. OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build();
  407. if (op_desc == nullptr) {
  408. REPORT_CALL_ERROR("E19999", "Add dynamic input or output to op:%s(%s) failed",
  409. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  410. GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str());
  411. return FAILED;
  412. }
  413. NodePtr while_node = graph->AddNode(op_desc);
  414. if (while_node == nullptr) {
  415. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  416. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  417. GELOGE(FAILED, "Create while node failed, name:%s.", name.c_str());
  418. return FAILED;
  419. }
  420. AddRePassNode(while_node);
  421. while_info.while_node = while_node;
  422. if (BuildWhileLink(while_info) != SUCCESS) {
  423. GELOGE(FAILED, "Build while link-edge failed, name:%s.", name.c_str());
  424. return FAILED;
  425. }
  426. GELOGD("Create while node succ, name:%s.", name.c_str());
  427. return SUCCESS;
  428. }
  429. ///
  430. /// @brief Build while link-edge
  431. /// @param [in] while_info
  432. /// @return Status
  433. ///
  434. Status ForPass::BuildWhileLink(const WhileInfo &while_info) {
  435. NodePtr while_node = while_info.while_node;
  436. GE_CHECK_NOTNULL(while_node);
  437. size_t input_num = while_info.data_inputs.size();
  438. for (size_t i = 0; i < input_num; i++) {
  439. InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i);
  440. GE_CHECK_NOTNULL(in_data_anchor);
  441. OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i];
  442. if (peer_out_anchor == nullptr) {
  443. continue;
  444. }
  445. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor),
  446. "Add data-edge %s:%d->%s:%zu failed.",
  447. peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
  448. while_node->GetName().c_str(), i);
  449. }
  450. size_t output_num = while_info.data_outputs.size();
  451. for (size_t i = 0; i < output_num; i++) {
  452. OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast<int>(i + kWhileOutputIndex));
  453. GE_CHECK_NOTNULL(out_data_anchor);
  454. for (auto &peer_in_anchor : while_info.data_outputs[i]) {
  455. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor),
  456. "Add data-edge %s:%zu->%s:%d failed.",
  457. while_node->GetName().c_str(), i + kWhileOutputIndex,
  458. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  459. }
  460. }
  461. InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor();
  462. GE_CHECK_NOTNULL(in_ctrl_anchor);
  463. for (auto &peer_out_anchor : while_info.ctrl_inputs) {
  464. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor),
  465. "Add ctrl-edge %s->%s failed.",
  466. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  467. in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
  468. }
  469. OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor();
  470. GE_CHECK_NOTNULL(out_ctrl_anchor);
  471. for (auto &peer_in_anchor : while_info.ctrl_outputs) {
  472. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor),
  473. "Add ctrl-edge %s->%s failed.",
  474. out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
  475. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  476. }
  477. return SUCCESS;
  478. }
  479. ///
  480. /// @brief Build cond_graph for while_node
  481. /// @param [in&out] while_info
  482. /// @return ComputeGraphPtr
  483. ///
  484. ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) {
  485. std::string cond_name = while_info.for_body_name + "_Cond";
  486. CompleteGraphBuilder graph_builder(cond_name);
  487. // Add parent node
  488. graph_builder.SetParentNode(while_info.while_node);
  489. // Add Node
  490. const std::string mul_name = "Mul";
  491. graph_builder.AddNode(CreateOpDesc(mul_name, MUL, false));
  492. const std::string less_name = "Less";
  493. graph_builder.AddNode(CreateOpDesc(less_name, LESS, false));
  494. // Set Input
  495. graph_builder.SetInput(kWhileIInputIndex, { mul_name }, { 0 })
  496. .SetInput(kWhileAbsDeltaInputIndex, { mul_name }, { 1 })
  497. .SetInput(kWhileRangeInputIndex, { less_name }, { 1 })
  498. .SetUselessInput(kWhileStartInputIndex)
  499. .SetUselessInput(kWhileDeltaInputIndex);
  500. size_t input_num = while_info.data_inputs.size();
  501. for (size_t i = kWhileDataInputIndex; i < input_num; i++) {
  502. graph_builder.SetUselessInput(i);
  503. }
  504. // Add Output
  505. graph_builder.AddOutput(less_name, 0);
  506. // Add Edges
  507. graph_builder.AddDataLink(mul_name, 0, less_name, 0);
  508. // Add Input-Mapping
  509. std::map<uint32_t, uint32_t> input_mapping;
  510. for (size_t i = 0; i < input_num; i++) {
  511. input_mapping[i] = i;
  512. }
  513. graph_builder.SetInputMapping(input_mapping);
  514. graphStatus error_code = GRAPH_SUCCESS;
  515. std::string error_msg;
  516. ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg);
  517. if (cond_graph == nullptr) {
  518. REPORT_CALL_ERROR("E19999", "Build graph:%s failed", cond_name.c_str());
  519. GELOGE(FAILED, "Build cond_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  520. return nullptr;
  521. }
  522. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  523. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND);
  524. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name);
  525. while_info.while_cond = cond_graph;
  526. return cond_graph;
  527. }
  528. ///
  529. /// @brief Build body_graph for while_node
  530. /// @param [in&out] while_info
  531. /// @return ComputeGraphPtr
  532. ///
  533. ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) {
  534. std::string body_name = while_info.for_body_name + "_Body";
  535. CompleteGraphBuilder graph_builder(body_name);
  536. // Add parent node
  537. graph_builder.SetParentNode(while_info.while_node);
  538. // Add calculation nodes
  539. std::string const_name = "Const";
  540. std::string add_name_0 = "Add_0";
  541. std::string mul_name = "Mul";
  542. std::string add_name_1 = "Add_1";
  543. graph_builder.AddNode(CreateConstDesc(const_name, 1))
  544. .AddNode(CreateOpDesc(add_name_0, ADD, false))
  545. .AddNode(CreateOpDesc(mul_name, MUL, false))
  546. .AddNode(CreateOpDesc(add_name_1, ADD, false));
  547. // Add Subgraph node
  548. auto input_num = static_cast<uint32_t>(while_info.data_inputs.size());
  549. std::string sub_graph_node_name = while_info.for_body_name;
  550. uint32_t sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex;
  551. auto sub_graph_output_num = static_cast<uint32_t>(while_info.data_outputs.size());
  552. graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num));
  553. // Set Input
  554. graph_builder.SetInput(kWhileIInputIndex, { add_name_0, mul_name }, { 0, 0 })
  555. .SetUselessInput(kWhileAbsDeltaInputIndex)
  556. .SetUselessInput(kWhileRangeInputIndex)
  557. .SetInput(kWhileStartInputIndex, { add_name_1 }, { 0 })
  558. .SetInput(kWhileDeltaInputIndex, { mul_name }, { 1 });
  559. for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) {
  560. graph_builder.SetInput(i + kWhileDataInputIndex, { sub_graph_node_name }, { i + kSubgraphInputIndex });
  561. }
  562. // Add Outputs
  563. graph_builder.AddOutput(add_name_0, 0);
  564. for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) {
  565. graph_builder.AddOutput("Data_" + std::to_string(i), 0);
  566. }
  567. for (uint32_t i = 0; i < sub_graph_output_num; i++) {
  568. graph_builder.AddOutput(sub_graph_node_name, i);
  569. }
  570. // Add Edges
  571. graph_builder.AddDataLink(const_name, 0, add_name_0, 1)
  572. .AddDataLink(mul_name, 0, add_name_1, 1)
  573. .AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex);
  574. // Add Input-Mapping
  575. std::map<uint32_t, uint32_t> input_mapping;
  576. for (size_t i = 0; i < input_num; i++) {
  577. input_mapping[i] = i;
  578. }
  579. graph_builder.SetInputMapping(input_mapping);
  580. // Add outputMapping
  581. std::map<uint32_t, uint32_t> output_mapping;
  582. for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) {
  583. output_mapping[i] = i;
  584. }
  585. graph_builder.SetOutputMapping(output_mapping);
  586. graphStatus error_code = GRAPH_SUCCESS;
  587. std::string error_msg;
  588. ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg);
  589. if (body_graph == nullptr) {
  590. GELOGE(FAILED, "Build body_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  591. return nullptr;
  592. }
  593. NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name);
  594. if (sub_graph_node == nullptr) {
  595. GELOGE(FAILED, "Get sub_graph_node failed: name:%s.", sub_graph_node_name.c_str());
  596. return nullptr;
  597. }
  598. while_info.sub_graph_node = sub_graph_node;
  599. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  600. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY);
  601. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name);
  602. while_info.while_body = body_graph;
  603. return body_graph;
  604. }
  605. ///
  606. /// @brief Create op_desc for subgraph node
  607. /// @param [in] name
  608. /// @param [in] input_num
  609. /// @param [in] output_num
  610. /// @return OpDescPtr
  611. ///
  612. OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) {
  613. OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
  614. op_desc_builder.AddDynamicInput("args", input_num)
  615. .AddDynamicOutput("output", output_num);
  616. OpDescPtr op_desc = op_desc_builder.Build();
  617. if (op_desc == nullptr) {
  618. REPORT_CALL_ERROR("E19999", "Build op_desc:%s(%s) failed",
  619. name.c_str(), PARTITIONEDCALL);
  620. GELOGE(FAILED, "Create op_desc for subgraph node failed, name:%s.", name.c_str());
  621. return nullptr;
  622. }
  623. size_t index = op_desc->GetSubgraphInstanceNames().size();
  624. op_desc->AddSubgraphName("f");
  625. op_desc->SetSubgraphInstanceName(index, name);
  626. return op_desc;
  627. }
  628. ///
  629. /// @brief Update InputMapping for for-body-graph
  630. /// @param [in] while_info
  631. /// @return Status
  632. ///
  633. Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
  634. ComputeGraphPtr for_body = while_info.for_body;
  635. GE_CHECK_NOTNULL(for_body);
  636. // index_of_cur_graph_node_input -> index_of_new_graph_node_input
  637. std::map<uint32_t, uint32_t> input_mapping;
  638. size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT;
  639. for (size_t i = 0; i < input_num; i++) {
  640. if (i == FOR_START_INPUT) {
  641. input_mapping[i] = i;
  642. } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
  643. continue;
  644. } else {
  645. input_mapping[i] = i - kIDiffValue;
  646. }
  647. }
  648. for_body->UpdateInputMapping(input_mapping);
  649. for_body->SetParentNode(while_info.sub_graph_node);
  650. for_body->SetParentGraph(while_info.while_body);
  651. return SUCCESS;
  652. }
  653. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。