You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

for_pass.cc 30 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/for_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "common/op/ge_op_utils.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "framework/common/debug/log.h"
  21. #include "framework/common/ge_inner_error_codes.h"
  22. #include "framework/common/types.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/type_utils.h"
  26. #include "graph/utils/node_utils.h"
  27. #include "graph/utils/op_desc_utils.h"
  28. namespace {
  29. const uint32_t kWhileIInputIndex = 0;
  30. const uint32_t kWhileAbsDeltaInputIndex = 1;
  31. const uint32_t kWhileRangeInputIndex = 2;
  32. const uint32_t kWhileStartInputIndex = 3;
  33. const uint32_t kWhileDeltaInputIndex = 4;
  34. const uint32_t kWhileDataInputIndex = 5;
  35. const uint32_t kSubgraphLoopVarInputIndex = 0;
  36. const uint32_t kSubgraphInputIndex = 1;
  37. const uint32_t kWhileOutputIndex = 5;
  38. const size_t kIDiffValue = 2;
  39. const std::string kAbs = "Abs";
  40. }
  41. namespace ge {
  42. Status ForPass::Run(NodePtr &node) {
  43. if (node->GetType() != FOR) {
  44. GELOGD("no need for_pass for node %s.", node->GetName().c_str());
  45. return SUCCESS;
  46. }
  47. GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str());
  48. ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  49. GE_CHECK_NOTNULL(graph);
  50. ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph);
  51. GE_CHECK_NOTNULL(root_graph);
  52. ForInfo for_info;
  53. GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info),
  54. "[Build][ForInfo] failed, node:%s.", node->GetName().c_str());
  55. WhileInfo while_info;
  56. GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info),
  57. "[Transfer][WhileInfo] from ForInfo failed, node:%s.", node->GetName().c_str());
  58. ComputeGraphPtr cond_graph = BuildCondGraph(while_info);
  59. if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) {
  60. REPORT_CALL_ERROR("E19999", "Build cond graph failed or add cond subgraph to root_graph:%s failed",
  61. root_graph->GetName().c_str());
  62. GELOGE(FAILED, "[Check][Param] Build cond graph failed or add cond subgraph to root_graph:%s failed.",
  63. root_graph->GetName().c_str());
  64. return FAILED;
  65. }
  66. ComputeGraphPtr body_graph = BuildBodyGraph(while_info);
  67. if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) {
  68. REPORT_CALL_ERROR("E19999", "Build body graph failed or add body subgraph to root_graph:%s failed",
  69. root_graph->GetName().c_str());
  70. GELOGE(FAILED, "[Check][Param] Build body graph failed or add body subgraph to root_graph:%s failed",
  71. root_graph->GetName().c_str());
  72. return FAILED;
  73. }
  74. GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info),
  75. "[Update][InputMapping] for for-body-graph failed, node:%s.", node->GetName().c_str());
  76. // for node has and only has one subgraph
  77. GE_CHECK_NOTNULL(node->GetOpDesc());
  78. node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0));
  79. GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str());
  80. return IsolateAndDeleteNode(node, std::vector<int>());
  81. }
  82. ///
  83. /// @brief Build for_info
  84. /// @param [in] root_graph
  85. /// @param [in] node
  86. /// @param [out] for_info
  87. /// @return Status
  88. ///
  89. Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) {
  90. GELOGI("Begin to build for_info for node %s.", node->GetName().c_str());
  91. OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT);
  92. OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT);
  93. OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT);
  94. if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) {
  95. REPORT_INNER_ERROR("E19999", "FOR_START_INPUT index:%d or FOR_LIMIT_INPUT index:%d or FOR_DELTA_INPUT index:%d "
  96. "in data anchor of op:%s(%s) lack, check invalid",
  97. FOR_START_INPUT, FOR_LIMIT_INPUT, FOR_DELTA_INPUT,
  98. node->GetName().c_str(), node->GetType().c_str());
  99. GELOGE(FAILED, "[Check][Param] FOR_START_INPUT index:%d or FOR_LIMIT_INPUT index:%d or FOR_DELTA_INPUT index:%d "
  100. "in data anchor of op:%s(%s) lack",
  101. FOR_START_INPUT, FOR_LIMIT_INPUT, FOR_DELTA_INPUT, node->GetName().c_str(), node->GetType().c_str());
  102. return FAILED;
  103. }
  104. std::vector<OutDataAnchorPtr> data_inputs;
  105. std::vector<std::vector<InDataAnchorPtr>> data_outputs;
  106. std::vector<OutControlAnchorPtr> ctrl_inputs;
  107. std::vector<InControlAnchorPtr> ctrl_outputs;
  108. if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) {
  109. GELOGE(FAILED, "[Find][InputsAndOutputs] in node:%s failed.", node->GetName().c_str());
  110. return FAILED;
  111. }
  112. NodeUtils::UnlinkAll(*node);
  113. OpDescPtr op_desc = node->GetOpDesc();
  114. GE_CHECK_NOTNULL(op_desc);
  115. // For node has and only has one sub_graph
  116. std::string for_body_name = op_desc->GetSubgraphInstanceName(0);
  117. if (for_body_name.empty()) {
  118. REPORT_INNER_ERROR("E19999", "Get subgraph name from op:%s(%s) by index 0 failed",
  119. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  120. GELOGE(FAILED, "[Get][SubGraphName] from op:%s(%s) by index 0 failed",
  121. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  122. return FAILED;
  123. }
  124. ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name);
  125. if (for_body == nullptr) {
  126. REPORT_INNER_ERROR("E19999", "Get subgraph from graph:%s by name:%s failed",
  127. root_graph->GetName().c_str(), for_body_name.c_str());
  128. GELOGE(FAILED, "[Get][SubGraph] from graph:%s by name:%s failed",
  129. root_graph->GetName().c_str(), for_body_name.c_str());
  130. return FAILED;
  131. }
  132. for_info.for_node = node;
  133. for_info.start = start;
  134. for_info.limit = limit;
  135. for_info.delta = delta;
  136. for_info.body_name = for_body_name;
  137. for_info.for_body = for_body;
  138. for_info.data_inputs = std::move(data_inputs);
  139. for_info.data_outputs = std::move(data_outputs);
  140. for_info.ctrl_inputs = std::move(ctrl_inputs);
  141. for_info.ctrl_outputs = std::move(ctrl_outputs);
  142. GELOGI("Build for_info for node %s success.", node->GetName().c_str());
  143. return SUCCESS;
  144. }
  145. ///
  146. /// @brief Find input with index for For node
  147. /// @param [in] node
  148. /// @param [in] index
  149. /// @return OutDataAnchorPtr
  150. ///
  151. OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) {
  152. if (node == nullptr) {
  153. GELOGE(FAILED, "[Check][Param] node is nullptr.");
  154. return nullptr;
  155. }
  156. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  157. if (in_data_anchor == nullptr) {
  158. GELOGE(FAILED, "[Get][InDataAnchor] failed, In Data Anchor index:%u in node:%s is nullptr.",
  159. index, node->GetName().c_str());
  160. return nullptr;
  161. }
  162. return in_data_anchor->GetPeerOutAnchor();
  163. }
  164. ///
  165. /// @brief Find inputs / outputs for for node
  166. /// @param [in] node
  167. /// @param [out] data_inputs
  168. /// @param [out] data_outputs
  169. /// @param [out] ctrl_inputs
  170. /// @param [out] ctrl_outputs
  171. /// @return Status
  172. ///
  173. Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
  174. std::vector<std::vector<ge::InDataAnchorPtr>> &data_outputs,
  175. std::vector<ge::OutControlAnchorPtr> &ctrl_inputs,
  176. std::vector<ge::InControlAnchorPtr> &ctrl_outputs) {
  177. GE_CHECK_NOTNULL(node);
  178. uint32_t input_data_num = node->GetAllInDataAnchorsSize();
  179. for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
  180. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  181. GE_CHECK_NOTNULL(in_data_anchor);
  182. data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
  183. }
  184. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  185. std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
  186. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  187. peer_in_data_anchors.emplace_back(peer_in_data_anchor);
  188. }
  189. data_outputs.emplace_back(peer_in_data_anchors);
  190. }
  191. InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
  192. GE_CHECK_NOTNULL(in_ctrl_anchor);
  193. for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  194. ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
  195. }
  196. OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
  197. GE_CHECK_NOTNULL(out_ctrl_anchor);
  198. for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  199. ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
  200. }
  201. return SUCCESS;
  202. }
  203. ///
  204. /// @brief Transfer while_info from for_info
  205. /// @param [in] graph
  206. /// @param [in] for_info
  207. /// @param [out] while_info
  208. /// @return Status
  209. ///
  210. Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) {
  211. std::string for_name = for_info.for_node->GetName();
  212. GELOGI("Begin to transfer for_info to while_info, node:%s.", for_name.c_str());
  213. std::string i_name = for_name + "_i";
  214. NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0));
  215. if (i_node == nullptr) {
  216. REPORT_CALL_ERROR("E19999", "Add node:%s(Const) to graph:%s failed", i_name.c_str(), graph->GetName().c_str());
  217. GELOGE(FAILED, "[Add][Node] %s(Const) to graph:%s failed", i_name.c_str(), graph->GetName().c_str());
  218. return FAILED;
  219. }
  220. AddRePassNode(i_node);
  221. std::string identity_name = i_name + "_Identity";
  222. NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true));
  223. // Const node has and only has one output, Identity node has and only has one input
  224. if ((identity_node == nullptr) ||
  225. (GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) {
  226. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
  227. i_node->GetName().c_str(), i_node->GetType().c_str(),
  228. identity_node->GetName().c_str(), identity_node->GetType().c_str());
  229. GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
  230. i_node->GetName().c_str(), i_node->GetType().c_str(),
  231. identity_node->GetName().c_str(), identity_node->GetType().c_str());
  232. return FAILED;
  233. }
  234. AddRePassNode(identity_node);
  235. // Identity node has and only has one output
  236. OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0);
  237. if (i_input == nullptr) {
  238. REPORT_INNER_ERROR("E19999", "Out data anchor index:0 in op:%s(%s) is nullptr, check invalid",
  239. identity_node->GetName().c_str(), identity_node->GetType().c_str());
  240. GELOGE(FAILED, "[Get][OutDataAnchor] failed, Out data anchor index:0 in op:%s(%s) is nullptr",
  241. identity_node->GetName().c_str(), identity_node->GetType().c_str());
  242. return FAILED;
  243. }
  244. OutDataAnchorPtr range_input = nullptr;
  245. OutDataAnchorPtr abs_delta_input = nullptr;
  246. if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) {
  247. GELOGE(FAILED, "[Create][LoopInput] failed, graph:%s.", graph->GetName().c_str());
  248. return FAILED;
  249. }
  250. BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info);
  251. if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) {
  252. GELOGE(FAILED, "[Insert][WhileNode] in graph:%s failed.", graph->GetName().c_str());
  253. return FAILED;
  254. }
  255. GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.",
  256. for_name.c_str(), while_info.while_node->GetName().c_str());
  257. return SUCCESS;
  258. }
  259. ///
  260. /// @brief Create const op_desc
  261. /// @param [in] name
  262. /// @param [in] value
  263. /// @return OpDescPtr
  264. ///
  265. OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) {
  266. OpDescPtr const_op_desc = MakeShared<OpDesc>(name, CONSTANT);
  267. if (const_op_desc == nullptr) {
  268. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  269. GELOGE(FAILED, "[New][OpDesc] failed.");
  270. return nullptr;
  271. }
  272. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32);
  273. GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  274. if (const_value == nullptr) {
  275. REPORT_CALL_ERROR("E19999", "New GeTensor failed");
  276. GELOGE(FAILED, "[New][GeTensor] failed");
  277. return nullptr;
  278. }
  279. if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) {
  280. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  281. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  282. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  283. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  284. return nullptr;
  285. }
  286. if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) {
  287. REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed, name:y",
  288. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  289. GELOGE(FAILED, "[Add][OutputDesc] to op:%s(%s) failed, name:y",
  290. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  291. return nullptr;
  292. }
  293. return const_op_desc;
  294. }
  295. ///
  296. /// @brief Create loop node
  297. /// @param [in] graph
  298. /// @param [in] for_info
  299. /// @param [out] range_input
  300. /// @param [out] abs_delta_input
  301. /// @return Status
  302. ///
  303. Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info,
  304. OutDataAnchorPtr &range_input, OutDataAnchorPtr &abs_delta_input) {
  305. std::string for_name = for_info.for_node->GetName();
  306. GELOGD("Begin to create loop_count input, node:%s", for_name.c_str());
  307. OutDataAnchorPtr start = for_info.start;
  308. OutDataAnchorPtr limit = for_info.limit;
  309. OutDataAnchorPtr delta = for_info.delta;
  310. std::string sub_name_0 = for_name + "_Sub_0";
  311. std::string abs_name_0 = for_name + "_Abs_0";
  312. std::string abs_name_1 = for_name + "_Abs_1";
  313. // i * |delta| < |limit-start|
  314. PartialGraphBuilder graph_builder;
  315. graph_builder.SetOwnerGraph(graph)
  316. .AddExistNode(for_info.start->GetOwnerNode())
  317. .AddExistNode(for_info.limit->GetOwnerNode())
  318. .AddExistNode(for_info.delta->GetOwnerNode())
  319. .AddNode(CreateOpDesc(sub_name_0, SUB, false))
  320. .AddNode(CreateOpDesc(abs_name_0, kAbs, true))
  321. .AddNode(CreateOpDesc(abs_name_1, kAbs, true))
  322. .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0)
  323. .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0)
  324. .AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1)
  325. .AddDataLink(sub_name_0, 0, abs_name_1, 0);
  326. graphStatus error_code = GRAPH_SUCCESS;
  327. std::string error_msg;
  328. if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) {
  329. REPORT_CALL_ERROR("E19999", "Add loop input node to graph:%s failed", graph->GetName().c_str());
  330. GELOGE(FAILED, "[Create][LoopInputNode] failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  331. return FAILED;
  332. }
  333. // Add repass_nodes
  334. for (auto &node : graph_builder.GetAllNodes()) {
  335. AddRePassNode(node);
  336. }
  337. NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0);
  338. NodePtr loop_count_node = graph_builder.GetNode(abs_name_1);
  339. if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) {
  340. REPORT_CALL_ERROR("E19999", "Add loop input node to graph:%s failed", graph->GetName().c_str());
  341. GELOGE(FAILED, "[Create][LoopNode] failed: node is nullptr, graph:%s.", graph->GetName().c_str());
  342. return FAILED;
  343. }
  344. GELOGD("Create loop_range input succ, node:%s", for_name.c_str());
  345. // abs_node has and only has one output
  346. abs_delta_input = abs_delta_node->GetOutDataAnchor(0);
  347. range_input = loop_count_node->GetOutDataAnchor(0);
  348. return SUCCESS;
  349. }
  350. ///
  351. /// @brief Create op_desc
  352. /// @param [in] name
  353. /// @param [in] type
  354. /// @param [in] io_equal_flag
  355. /// @return OpDescPtr
  356. ///
  357. OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) {
  358. OpDescBuilder op_desc_builder(name, type);
  359. if (io_equal_flag) {
  360. op_desc_builder.AddInput("x")
  361. .AddOutput("y");
  362. } else {
  363. op_desc_builder.AddInput("x1")
  364. .AddInput("x2")
  365. .AddOutput("y");
  366. }
  367. return op_desc_builder.Build();
  368. }
  369. ///
  370. /// @brief Build while-info
  371. /// @param [in] for_info
  372. /// @param [in] i_input
  373. /// @param [in] range_input
  374. /// @param [in] abs_delta_input
  375. /// @param [out] while_info
  376. /// @return void
  377. ///
  378. void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
  379. const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
  380. WhileInfo &while_info) {
  381. while_info.i = i_input;
  382. while_info.abs_delta = abs_delta_input;
  383. while_info.range = range_input;
  384. while_info.start = for_info.start;
  385. while_info.delta = for_info.delta;
  386. while_info.for_body_name = for_info.body_name;
  387. while_info.for_body = for_info.for_body;
  388. while_info.data_inputs.emplace_back(while_info.i);
  389. while_info.data_inputs.emplace_back(while_info.abs_delta);
  390. while_info.data_inputs.emplace_back(while_info.range);
  391. while_info.data_inputs.emplace_back(while_info.start);
  392. while_info.data_inputs.emplace_back(while_info.delta);
  393. for (auto &item : for_info.data_inputs) {
  394. while_info.data_inputs.emplace_back(item);
  395. }
  396. for (auto &item : for_info.data_outputs) {
  397. while_info.data_outputs.emplace_back(item);
  398. }
  399. for (auto &item : for_info.ctrl_inputs) {
  400. while_info.ctrl_inputs.emplace_back(item);
  401. }
  402. for (auto &item : for_info.ctrl_outputs) {
  403. while_info.ctrl_outputs.emplace_back(item);
  404. }
  405. }
  406. ///
  407. /// @brief Insert while_node
  408. /// @param [in] graph
  409. /// @param [in] name
  410. /// @param [in&out] while_info
  411. /// @return Status
  412. ///
  413. Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) {
  414. GELOGD("Begin to create while node, name:%s.", name.c_str());
  415. size_t arg_num = while_info.data_inputs.size();
  416. OpDescBuilder op_desc_builder(name, WHILE);
  417. OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build();
  418. if (op_desc == nullptr) {
  419. REPORT_CALL_ERROR("E19999", "Add dynamic input or output to op:%s(%s) failed",
  420. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  421. GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str());
  422. return FAILED;
  423. }
  424. NodePtr while_node = graph->AddNode(op_desc);
  425. if (while_node == nullptr) {
  426. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  427. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  428. GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed",
  429. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  430. return FAILED;
  431. }
  432. AddRePassNode(while_node);
  433. while_info.while_node = while_node;
  434. if (BuildWhileLink(while_info) != SUCCESS) {
  435. GELOGE(FAILED, "[Build][WhileLink] failed, node:%s.", while_node->GetName().c_str());
  436. return FAILED;
  437. }
  438. GELOGD("Create while node succ, name:%s.", name.c_str());
  439. return SUCCESS;
  440. }
  441. ///
  442. /// @brief Build while link-edge
  443. /// @param [in] while_info
  444. /// @return Status
  445. ///
  446. Status ForPass::BuildWhileLink(const WhileInfo &while_info) {
  447. NodePtr while_node = while_info.while_node;
  448. GE_CHECK_NOTNULL(while_node);
  449. size_t input_num = while_info.data_inputs.size();
  450. for (size_t i = 0; i < input_num; i++) {
  451. InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i);
  452. GE_CHECK_NOTNULL(in_data_anchor);
  453. OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i];
  454. if (peer_out_anchor == nullptr) {
  455. continue;
  456. }
  457. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor),
  458. "[Add][DataEdge] %s:%d->%s:%zu failed.",
  459. peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
  460. while_node->GetName().c_str(), i);
  461. }
  462. size_t output_num = while_info.data_outputs.size();
  463. for (size_t i = 0; i < output_num; i++) {
  464. OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast<int>(i + kWhileOutputIndex));
  465. GE_CHECK_NOTNULL(out_data_anchor);
  466. for (auto &peer_in_anchor : while_info.data_outputs[i]) {
  467. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor),
  468. "[Add][DataEdge] %s:%zu->%s:%d failed.",
  469. while_node->GetName().c_str(), i + kWhileOutputIndex,
  470. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  471. }
  472. }
  473. InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor();
  474. GE_CHECK_NOTNULL(in_ctrl_anchor);
  475. for (auto &peer_out_anchor : while_info.ctrl_inputs) {
  476. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor),
  477. "[Add][CtrlEdge] %s->%s failed.",
  478. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  479. in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
  480. }
  481. OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor();
  482. GE_CHECK_NOTNULL(out_ctrl_anchor);
  483. for (auto &peer_in_anchor : while_info.ctrl_outputs) {
  484. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor),
  485. "[Add][CtrlEdge] %s->%s failed.",
  486. out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
  487. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  488. }
  489. return SUCCESS;
  490. }
  491. ///
  492. /// @brief Build cond_graph for while_node
  493. /// @param [in&out] while_info
  494. /// @return ComputeGraphPtr
  495. ///
  496. ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) {
  497. std::string cond_name = while_info.for_body_name + "_Cond";
  498. CompleteGraphBuilder graph_builder(cond_name);
  499. // Add parent node
  500. graph_builder.SetParentNode(while_info.while_node);
  501. // Add Node
  502. const std::string mul_name = "Mul";
  503. graph_builder.AddNode(CreateOpDesc(mul_name, MUL, false));
  504. const std::string less_name = "Less";
  505. graph_builder.AddNode(CreateOpDesc(less_name, LESS, false));
  506. // Set Input
  507. graph_builder.SetInput(kWhileIInputIndex, { mul_name }, { 0 })
  508. .SetInput(kWhileAbsDeltaInputIndex, { mul_name }, { 1 })
  509. .SetInput(kWhileRangeInputIndex, { less_name }, { 1 })
  510. .SetUselessInput(kWhileStartInputIndex)
  511. .SetUselessInput(kWhileDeltaInputIndex);
  512. size_t input_num = while_info.data_inputs.size();
  513. for (size_t i = kWhileDataInputIndex; i < input_num; i++) {
  514. graph_builder.SetUselessInput(i);
  515. }
  516. // Add Output
  517. graph_builder.AddOutput(less_name, 0);
  518. // Add Edges
  519. graph_builder.AddDataLink(mul_name, 0, less_name, 0);
  520. // Add Input-Mapping
  521. std::map<uint32_t, uint32_t> input_mapping;
  522. for (size_t i = 0; i < input_num; i++) {
  523. input_mapping[i] = i;
  524. }
  525. graph_builder.SetInputMapping(input_mapping);
  526. graphStatus error_code = GRAPH_SUCCESS;
  527. std::string error_msg;
  528. ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg);
  529. if (cond_graph == nullptr) {
  530. REPORT_CALL_ERROR("E19999", "Build graph:%s failed", cond_name.c_str());
  531. GELOGE(FAILED, "[Build][CondGraph] failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  532. return nullptr;
  533. }
  534. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  535. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND);
  536. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name);
  537. while_info.while_cond = cond_graph;
  538. return cond_graph;
  539. }
  540. ///
  541. /// @brief Build body_graph for while_node
  542. /// @param [in&out] while_info
  543. /// @return ComputeGraphPtr
  544. ///
  545. ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) {
  546. std::string body_name = while_info.for_body_name + "_Body";
  547. CompleteGraphBuilder graph_builder(body_name);
  548. // Add parent node
  549. graph_builder.SetParentNode(while_info.while_node);
  550. // Add calculation nodes
  551. std::string const_name = "Const";
  552. std::string add_name_0 = "Add_0";
  553. std::string mul_name = "Mul";
  554. std::string add_name_1 = "Add_1";
  555. graph_builder.AddNode(CreateConstDesc(const_name, 1))
  556. .AddNode(CreateOpDesc(add_name_0, ADD, false))
  557. .AddNode(CreateOpDesc(mul_name, MUL, false))
  558. .AddNode(CreateOpDesc(add_name_1, ADD, false));
  559. // Add Subgraph node
  560. auto input_num = static_cast<uint32_t>(while_info.data_inputs.size());
  561. std::string sub_graph_node_name = while_info.for_body_name;
  562. uint32_t sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex;
  563. auto sub_graph_output_num = static_cast<uint32_t>(while_info.data_outputs.size());
  564. graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num));
  565. // Set Input
  566. graph_builder.SetInput(kWhileIInputIndex, { add_name_0, mul_name }, { 0, 0 })
  567. .SetUselessInput(kWhileAbsDeltaInputIndex)
  568. .SetUselessInput(kWhileRangeInputIndex)
  569. .SetInput(kWhileStartInputIndex, { add_name_1 }, { 0 })
  570. .SetInput(kWhileDeltaInputIndex, { mul_name }, { 1 });
  571. for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) {
  572. graph_builder.SetInput(i + kWhileDataInputIndex, { sub_graph_node_name }, { i + kSubgraphInputIndex });
  573. }
  574. // Add Outputs
  575. graph_builder.AddOutput(add_name_0, 0);
  576. for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) {
  577. graph_builder.AddOutput("Data_" + std::to_string(i), 0);
  578. }
  579. for (uint32_t i = 0; i < sub_graph_output_num; i++) {
  580. graph_builder.AddOutput(sub_graph_node_name, i);
  581. }
  582. // Add Edges
  583. graph_builder.AddDataLink(const_name, 0, add_name_0, 1)
  584. .AddDataLink(mul_name, 0, add_name_1, 1)
  585. .AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex);
  586. // Add Input-Mapping
  587. std::map<uint32_t, uint32_t> input_mapping;
  588. for (size_t i = 0; i < input_num; i++) {
  589. input_mapping[i] = i;
  590. }
  591. graph_builder.SetInputMapping(input_mapping);
  592. // Add outputMapping
  593. std::map<uint32_t, uint32_t> output_mapping;
  594. for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) {
  595. output_mapping[i] = i;
  596. }
  597. graph_builder.SetOutputMapping(output_mapping);
  598. graphStatus error_code = GRAPH_SUCCESS;
  599. std::string error_msg;
  600. ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg);
  601. if (body_graph == nullptr) {
  602. GELOGE(FAILED, "[Build][BodyGraph] failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  603. return nullptr;
  604. }
  605. NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name);
  606. if (sub_graph_node == nullptr) {
  607. GELOGE(FAILED, "[Get][Node] by name:%s failed.", sub_graph_node_name.c_str());
  608. return nullptr;
  609. }
  610. while_info.sub_graph_node = sub_graph_node;
  611. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  612. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY);
  613. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name);
  614. while_info.while_body = body_graph;
  615. return body_graph;
  616. }
  617. ///
  618. /// @brief Create op_desc for subgraph node
  619. /// @param [in] name
  620. /// @param [in] input_num
  621. /// @param [in] output_num
  622. /// @return OpDescPtr
  623. ///
  624. OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) {
  625. OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
  626. op_desc_builder.AddDynamicInput("args", input_num)
  627. .AddDynamicOutput("output", output_num);
  628. OpDescPtr op_desc = op_desc_builder.Build();
  629. if (op_desc == nullptr) {
  630. REPORT_CALL_ERROR("E19999", "Build op_desc:%s(%s) failed", name.c_str(), PARTITIONEDCALL);
  631. GELOGE(FAILED, "[Build][OpDesc] %s(%s) failed", name.c_str(), PARTITIONEDCALL);
  632. return nullptr;
  633. }
  634. size_t index = op_desc->GetSubgraphInstanceNames().size();
  635. op_desc->AddSubgraphName("f");
  636. op_desc->SetSubgraphInstanceName(index, name);
  637. return op_desc;
  638. }
  639. ///
  640. /// @brief Update InputMapping for for-body-graph
  641. /// @param [in] while_info
  642. /// @return Status
  643. ///
  644. Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
  645. ComputeGraphPtr for_body = while_info.for_body;
  646. GE_CHECK_NOTNULL(for_body);
  647. // index_of_cur_graph_node_input -> index_of_new_graph_node_input
  648. std::map<uint32_t, uint32_t> input_mapping;
  649. size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT;
  650. for (size_t i = 0; i < input_num; i++) {
  651. if (i == FOR_START_INPUT) {
  652. input_mapping[i] = i;
  653. } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
  654. continue;
  655. } else {
  656. input_mapping[i] = i - kIDiffValue;
  657. }
  658. }
  659. for_body->UpdateInputMapping(input_mapping);
  660. for_body->SetParentNode(while_info.sub_graph_node);
  661. for_body->SetParentGraph(while_info.while_body);
  662. return SUCCESS;
  663. }
  664. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示