You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mem_rw_conflict_optimize.cc 32 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. * http://www.apache.org/licenses/LICENSE-2.0
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. * See the License for the specific language governing permissions and
  11. * limitations under the License.
  12. */
  #include <atomic>
  #include <map>
  #include <set>
  #include <string>
  #include <vector>
  #include "common/ge/ge_util.h"
  #include "graph/common/omg_util.h"
  #include "graph/debug/ge_attr_define.h"
  #include "graph/optimize/graph_optimize.h"
  #include "graph/utils/graph_utils.h"
  #include "graph/utils/node_utils.h"
  21. namespace {
  22. using namespace ge;
  // Index of the single data input/output anchor of an Identity node.
  const int kIdentityAnchorIndex = 0;
  // Number of enumerators in InputRWType / OutputRWType (kInvalidRWType included).
  const size_t kSerialStringVecSize = 4;
  // Sums of the distinct InputRWType values reported by the consumers of one value
  // (used by GetInputRwTypeInConflict). They must agree with the InputRWType enumerator values:
  //   {ReadOnly} or no consumer                              -> 0
  //   {Writeable} or {ReadOnly, Writeable}                   -> 1
  //   {ScopeWriteable} or {ReadOnly, ScopeWriteable}         -> 2
  //   {Writeable, ScopeWriteable}, {InvalidRWType}, ...      -> 3 (conflict / invalid)
  const int kCaseReadOnly = 0;
  const int kCaseScopeWriteable = 2;
  const int kCaseWriteable = 1;
  const int kCaseInvalidRWType = 3;
  // Bool attr; when true, repeated uses of one producer output by the op get separate Identity nodes
  // (see HandleAllreduceDuplicateInput).
  const char *const kInputMutable = "_input_mutable";
  // rw type of input. Explicit values: the values are used as table indices and summed.
  enum class InputRWType {
    kReadOnly = 0,        // Normal op input, only read
    kWriteable = 1,       // Op like Assign/ApplyMomentum
    kScopeWriteable = 2,  // Op like hcom_allreduce: modifies its input, but the change must not reach the producer
    kInvalidRWType = 3
  };
  // rw type of output. Explicit values: the values are used as table indices.
  enum class OutputRWType {
    kReadOnly = 0,   // 1. const output 2. non-ref output with several consumers
    kSoftRead = 1,   // non-ref output with a single consumer
    kWriteable = 2,  // ref output, like Assign/ApplyMomentum
    kInvalidRWType = 3
  };
  // Input and output rw_type of one node; key is the anchor index, value is the rw_type.
  struct NodeInputOutputRWType {
    std::map<uint32_t, InputRWType> input_rw_type_map;
    std::map<uint32_t, OutputRWType> output_rw_type_map;
  };
  // Input and output rw_type of subgraph Data/NetOutput nodes in the current graph, keyed by node name.
  // Cleared at the start of GraphOptimize::CheckRWConflict. thread_local, presumably so that graphs
  // optimized on different threads do not share this state.
  thread_local std::map<std::string, NodeInputOutputRWType> node_rwtype_map_;
  51. ///
  52. /// @brief Convert input rw_type enum to string. For log print.
  53. /// @param rw_type
  54. /// @return rw_type_name
  55. ///
  56. static std::string InputRWTypeToSerialString(InputRWType rw_type) {
  57. const static char *names[kSerialStringVecSize] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"};
  58. return names[static_cast<int>(rw_type)];
  59. }
  60. ///
  61. /// @brief Convert output rw_type enum to string. For log print.
  62. /// @param rw_type
  63. /// @return rw_type_name
  64. ///
  65. static std::string OutputRWTypeToSerialString(OutputRWType rw_type) {
  66. const static char *names[kSerialStringVecSize] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"};
  67. return names[static_cast<int>(rw_type)];
  68. }
  69. OutputRWType GetSingleNodeOutputRWTypeByIndex(const Node &node, uint32_t index) {
  70. auto op_desc = node.GetOpDesc();
  71. if (op_desc == nullptr) {
  72. return OutputRWType::kInvalidRWType;
  73. }
  74. if (op_desc->GetType() == VARIABLE) {
  75. return OutputRWType::kWriteable;
  76. }
  77. // check if it is ref output
  78. auto input_names = op_desc->GetAllInputName();
  79. for (auto &input_name_2_idx : input_names) {
  80. if (op_desc->GetOutputNameByIndex(index) == input_name_2_idx.first) {
  81. return OutputRWType::kWriteable;
  82. }
  83. }
  84. // check if it is ref switch
  85. std::string type;
  86. if ((node.GetType() == FRAMEWORK_OP_TYPE) && AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)
  87. && (type == REFSWITCH)) {
  88. return OutputRWType::kWriteable;
  89. }
  90. if (op_desc->GetType() == CONSTANT || op_desc->GetType() == CONSTANTOP) {
  91. return OutputRWType::kReadOnly;
  92. }
  93. auto out_data_anchor = node.GetOutDataAnchor(index);
  94. if (out_data_anchor == nullptr) {
  95. return OutputRWType::kInvalidRWType;
  96. }
  97. if (out_data_anchor->GetPeerInDataNodesSize() > 1) {
  98. return OutputRWType::kReadOnly;
  99. } else {
  100. return OutputRWType::kSoftRead;
  101. }
  102. }
  103. ///
  104. /// @brief Get input rw_type of one node with sub graph. It will return rw_type after solve conflict scene.
  105. /// @param rw_type_set
  106. /// @return
  107. ///
  108. InputRWType GetInputRwTypeInConflict(const std::set<int> &rw_type_set) {
  109. // for input rw type calc
  110. int total_rw_type = 0;
  111. for (const auto rw : rw_type_set) {
  112. total_rw_type += rw;
  113. }
  114. switch (total_rw_type) {
  115. case kCaseReadOnly:
  116. return InputRWType::kReadOnly; // all input rw type is readonly
  117. case kCaseScopeWriteable:
  118. return InputRWType::kScopeWriteable; // readonly 2 scope_writeable
  119. case kCaseWriteable:
  120. return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable
  121. case kCaseInvalidRWType:
  122. return InputRWType::kInvalidRWType; // writeable 2 scope_writeable
  123. default:
  124. return InputRWType::kInvalidRWType;
  125. }
  126. }
  127. bool IsSubgraphInputNode(const NodePtr &node) {
  128. if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != DATA) ||
  129. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
  130. return false;
  131. }
  132. return true;
  133. }
  134. bool IsSubgraphOutputNode(const NodePtr &node) {
  135. if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != NETOUTPUT) ||
  136. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
  137. return false;
  138. }
  139. return true;
  140. }
  141. NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) {
  142. if (src_node.GetOpDesc() == nullptr) {
  143. return nullptr;
  144. }
  145. static std::atomic_long identity_num(0);
  146. auto next_num = identity_num.fetch_add(1);
  147. // 1. create new identity op desc
  148. string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num);
  149. auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY);
  150. if (identity_opdesc == nullptr) {
  151. GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str());
  152. return nullptr;
  153. }
  154. auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx);
  155. // 2. add input_desc & output_desc for new identity
  156. Status ret = identity_opdesc->AddInputDesc("x", data_desc);
  157. if (ret != SUCCESS) {
  158. GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str());
  159. return nullptr;
  160. }
  161. ret = identity_opdesc->AddOutputDesc("y", data_desc);
  162. if (ret != SUCCESS) {
  163. GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str());
  164. return nullptr;
  165. }
  166. GELOGI("Insert new Identity node %s.", identity_name.c_str());
  167. auto graph = src_node.GetOwnerComputeGraph();
  168. if (graph == nullptr) {
  169. GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str());
  170. return nullptr;
  171. }
  172. return graph->AddNode(identity_opdesc);
  173. }
  174. OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) {
  175. auto op_desc = node.GetOpDesc();
  176. if (op_desc == nullptr) {
  177. return OutputRWType::kInvalidRWType;
  178. }
  179. if (op_desc->GetType() == WHILE) {
  180. return OutputRWType::kSoftRead;
  181. }
  182. vector<string> subgraph_names = op_desc->GetSubgraphInstanceNames();
  183. if (subgraph_names.empty()) {
  184. // single node without sub graph
  185. return GetSingleNodeOutputRWTypeByIndex(node, index);
  186. } else {
  187. // node with sub graph
  188. auto output_node_vec = NodeUtils::GetSubgraphOutputNodes(node);
  189. auto output_rw_type = OutputRWType::kInvalidRWType;
  190. if (output_node_vec.size() == 1) {
  191. // find rw type from map.
  192. auto iter = node_rwtype_map_.find(output_node_vec.at(0)->GetName());
  193. if (iter == node_rwtype_map_.end()) {
  194. GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
  195. output_node_vec.at(0)->GetName().c_str());
  196. return OutputRWType::kInvalidRWType;
  197. }
  198. auto index_2_output_rw_type = iter->second.output_rw_type_map.find(index);
  199. if (index_2_output_rw_type == iter->second.output_rw_type_map.end()) {
  200. GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
  201. output_node_vec.at(0)->GetName().c_str());
  202. return OutputRWType::kInvalidRWType;
  203. }
  204. output_rw_type = index_2_output_rw_type->second;
  205. } else {
  206. output_rw_type = OutputRWType::kSoftRead;
  207. }
  208. // check peer input
  209. auto out_data_anchor = node.GetOutDataAnchor(index);
  210. if (out_data_anchor == nullptr) {
  211. return OutputRWType::kInvalidRWType;
  212. }
  213. if (out_data_anchor->GetPeerInDataNodesSize() > 1) {
  214. return OutputRWType::kReadOnly;
  215. } else {
  216. return output_rw_type;
  217. }
  218. }
  219. }
  220. InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) {
  221. auto op_desc = node.GetOpDesc();
  222. if (op_desc == nullptr) {
  223. return InputRWType::kInvalidRWType;
  224. }
  225. if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMALLGATHER
  226. || op_desc->GetType() == HCOMREDUCESCATTER || op_desc->GetType() == HCOMREDUCE) {
  227. return InputRWType::kScopeWriteable;
  228. }
  229. // check if it is ref input
  230. auto output_names = op_desc->GetAllOutputName();
  231. for (auto &output_name_2_idx : output_names) {
  232. if (op_desc->GetInputNameByIndex(index) == output_name_2_idx.first) {
  233. return InputRWType::kWriteable;
  234. }
  235. }
  236. // check if it is ref switch
  237. std::string type;
  238. if ((node.GetType() == FRAMEWORK_OP_TYPE) && (AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type))
  239. && (type == REFSWITCH) && (index == 0)) {
  240. return InputRWType::kWriteable;
  241. }
  242. return InputRWType::kReadOnly;
  243. }
  244. InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) {
  245. auto op_desc = node.GetOpDesc();
  246. if (op_desc == nullptr) {
  247. return InputRWType::kInvalidRWType;
  248. }
  249. if (op_desc->GetType() == WHILE) {
  250. return InputRWType::kScopeWriteable;
  251. }
  252. vector<string> subgraph_names = op_desc->GetSubgraphInstanceNames();
  253. if (subgraph_names.empty()) {
  254. // single node without sub graph
  255. return GetSingleNodeInputRWTypeByIndex(node, index);
  256. } else {
  257. // node with sub graph
  258. std::set<int> node_rw_type_set;
  259. auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index);
  260. // get all input data node in subgraph
  261. std::set<int> anchor_rw_type_set;
  262. for (const auto &data_node : data_node_vec) {
  263. // Data only has 1 out data anchor. Here just take first out data anchor. And index 0 is valid.
  264. auto out_data_anchor = data_node->GetOutDataAnchor(0);
  265. if (out_data_anchor == nullptr) {
  266. continue;
  267. }
  268. auto data_op_desc = data_node->GetOpDesc();
  269. if (data_op_desc == nullptr) {
  270. continue;
  271. }
  272. // find rw type from map.
  273. auto iter = node_rwtype_map_.find(data_op_desc->GetName());
  274. if (iter == node_rwtype_map_.end()) {
  275. GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
  276. data_op_desc->GetName().c_str());
  277. return InputRWType::kInvalidRWType;
  278. }
  279. auto input_rw_type = iter->second.input_rw_type_map.find(out_data_anchor->GetIdx());
  280. if (input_rw_type == iter->second.input_rw_type_map.end()) {
  281. GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
  282. data_op_desc->GetName().c_str());
  283. return InputRWType::kInvalidRWType;
  284. }
  285. anchor_rw_type_set.emplace(static_cast<int>(input_rw_type->second));
  286. }
  287. return GetInputRwTypeInConflict(anchor_rw_type_set);
  288. }
  289. }
  290. Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) {
  291. for (const auto &node : sub_graph->GetDirectNode()) {
  292. GE_CHECK_NOTNULL(node);
  293. GE_CHECK_NOTNULL(node->GetOpDesc());
  294. std::set<int> anchor_rw_type_set;
  295. if (node->GetType() == DATA) {
  296. // calc all input_rw_type of peer output , as input_rw_type of DATA. Index 0 is valid.
  297. auto anchor_2_node_vec = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, 0);
  298. for (const auto anchor_2_node_pair : anchor_2_node_vec) {
  299. auto input_rw_type = GetInputRWTypeByIndex(*anchor_2_node_pair.second, anchor_2_node_pair.first->GetIdx());
  300. GELOGD("Input rw type of Node %s %dth input anchor is %s", anchor_2_node_pair.second->GetName().c_str(),
  301. anchor_2_node_pair.first->GetIdx(), InputRWTypeToSerialString(input_rw_type).c_str());
  302. anchor_rw_type_set.emplace(static_cast<int>(input_rw_type));
  303. }
  304. auto anchor_rw_type = GetInputRwTypeInConflict(anchor_rw_type_set);
  305. GELOGD("Input rw type of Node %s is %s", node->GetName().c_str(),
  306. InputRWTypeToSerialString(anchor_rw_type).c_str());
  307. map<uint32_t, InputRWType> input_rw_type_map{std::make_pair(0, anchor_rw_type)};
  308. NodeInputOutputRWType data_rw_type{input_rw_type_map};
  309. node_rwtype_map_.emplace(std::make_pair(node->GetName(), data_rw_type));
  310. }
  311. if (node->GetType() == NETOUTPUT) {
  312. // calc all output_rw_type of peer input , as output_rw_type of DATA
  313. map<uint32_t, OutputRWType> output_rw_type_map;
  314. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  315. GE_CHECK_NOTNULL(in_data_anchor);
  316. auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
  317. GE_CHECK_NOTNULL(pre_out_anchor);
  318. auto pre_node = pre_out_anchor->GetOwnerNode();
  319. GE_CHECK_NOTNULL(pre_node);
  320. auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx());
  321. GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(),
  322. pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str());
  323. auto parent_node = sub_graph->GetParentNode();
  324. if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) {
  325. // insert identity
  326. auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
  327. GE_CHECK_NOTNULL(identity_node);
  328. auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
  329. if (ret != SUCCESS) {
  330. GELOGE(ret, "Fail to insert identity");
  331. return ret;
  332. }
  333. GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
  334. pre_node->GetName().c_str(), node->GetName().c_str());
  335. pre_output_rw_type = OutputRWType::kSoftRead;
  336. }
  337. output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), pre_output_rw_type));
  338. }
  339. NodeInputOutputRWType output_rw_type{{}, output_rw_type_map};
  340. node_rwtype_map_.emplace(std::make_pair(node->GetName(), output_rw_type));
  341. }
  342. }
  343. return SUCCESS;
  344. }
  345. ///
  346. /// @brief Reverse traversal all subgraph and mark rw_type for Data/Netoutput.
  347. /// @param sub_graph_vecgs
  348. ///
  349. Status MarkRWTypeForAllSubgraph(const vector<ComputeGraphPtr> &sub_graph_vec) {
  350. for (auto iter = sub_graph_vec.rbegin(); iter != sub_graph_vec.rend(); ++iter) {
  351. auto parent_node = (*iter)->GetParentNode();
  352. if (parent_node == nullptr) {
  353. GELOGD("Current sub graph has no parent node. Ignore it.");
  354. continue;
  355. }
  356. if (parent_node->GetType() == WHILE) {
  357. continue;
  358. }
  359. auto ret = MarkRWTypeForSubgraph(*iter);
  360. if (ret != SUCCESS) {
  361. return ret;
  362. }
  363. }
  364. return SUCCESS;
  365. }
  366. ///
  367. /// @brief Check identity is near subgraph.
  368. /// Eg. As output of Data node in subgraph
  369. /// or as input of Netoutput of subgraph
  370. /// or as input of one node with subgraph
  371. /// or as output of one node with subgraph
  372. /// @param node
  373. /// @return is_near_subgraph
  374. ///
  375. bool CheckIdentityIsNearSubgraph(const Node &node) {
  376. for (const auto &in_node : node.GetInDataNodes()) {
  377. auto in_node_opdesc = in_node->GetOpDesc();
  378. if (in_node_opdesc == nullptr) {
  379. continue;
  380. }
  381. // near entrance of subgraph
  382. if (IsSubgraphInputNode(in_node)) {
  383. return true;
  384. }
  385. // near subgraph
  386. if (!in_node_opdesc->GetSubgraphInstanceNames().empty()) {
  387. return true;
  388. }
  389. }
  390. for (const auto &out_node : node.GetOutDataNodes()) {
  391. auto out_node_opdesc = out_node->GetOpDesc();
  392. if (out_node_opdesc == nullptr) {
  393. continue;
  394. }
  395. // near output of subgraph
  396. if (IsSubgraphOutputNode(out_node)) {
  397. return true;
  398. }
  399. // near subgraph
  400. if (!out_node_opdesc->GetSubgraphInstanceNames().empty()) {
  401. return true;
  402. }
  403. }
  404. return false;
  405. }
  406. enum ConflictResult { DO_NOTHING, WRONG_GRAPH, INSERT_IDENTITY };
  407. vector<vector<ConflictResult>> output_2_input_rwtype = {{DO_NOTHING, WRONG_GRAPH, INSERT_IDENTITY},
  408. {DO_NOTHING, WRONG_GRAPH, DO_NOTHING},
  409. {DO_NOTHING, DO_NOTHING, INSERT_IDENTITY}};
  410. ConflictResult GetConflictResultBetweenNode(const OutputRWType output_rw_type, const InputRWType input_rw_type) {
  411. if (output_rw_type == OutputRWType::kInvalidRWType || input_rw_type == InputRWType::kInvalidRWType) {
  412. return WRONG_GRAPH;
  413. }
  414. auto n = static_cast<int>(output_rw_type);
  415. auto m = static_cast<int>(input_rw_type);
  416. // no need to check index or container, because container and index is all defined.
  417. return output_2_input_rwtype[n][m];
  418. }
  419. ///
  420. /// @brief Keep identity_node which near subgraph or has multi output
  421. /// @param node
  422. /// @return
  423. ///
  424. Status RemoveNoUseIdentity(const NodePtr &node) {
  425. if (node->GetInDataNodes().empty() || node->GetOutDataNodesSize() > 1) {
  426. return SUCCESS;
  427. }
  428. if (node->GetOutDataNodesSize() == 1 && node->GetOutDataNodes().at(0)->GetType() == STREAMMERGE) {
  429. return SUCCESS;
  430. }
  431. if (CheckIdentityIsNearSubgraph(*node)) {
  432. return SUCCESS;
  433. }
  434. GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex));
  435. auto pre_out_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor();
  436. GE_CHECK_NOTNULL(pre_out_anchor);
  437. auto pre_node = pre_out_anchor->GetOwnerNode();
  438. auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx());
  439. auto anchor_2_outnode_vec = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kIdentityAnchorIndex);
  440. ConflictResult conflict_result = WRONG_GRAPH;
  441. if (!anchor_2_outnode_vec.empty()) {
  442. auto anchor_2_outnode = anchor_2_outnode_vec.at(0);
  443. auto peer_input_rw_type = GetInputRWTypeByIndex(*anchor_2_outnode.second, anchor_2_outnode.first->GetIdx());
  444. GELOGD("Pre Node %s %dth output rw type is %s, peer node %s %dth input rw type is %s.", pre_node->GetName().c_str(),
  445. pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(),
  446. anchor_2_outnode.second->GetName().c_str(), anchor_2_outnode.first->GetIdx(),
  447. InputRWTypeToSerialString(peer_input_rw_type).c_str());
  448. conflict_result = GetConflictResultBetweenNode(pre_output_rw_type, peer_input_rw_type);
  449. } else {
  450. // identity node has no out data node, it can be removed
  451. conflict_result = DO_NOTHING;
  452. }
  453. if (conflict_result != DO_NOTHING) {
  454. return SUCCESS;
  455. }
  456. GELOGI("No need insert Identity. Node %s need to remove.", node->GetName().c_str());
  457. auto ret = GraphUtils::IsolateNode(node, {0});
  458. if (ret != SUCCESS) {
  459. GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str());
  460. return ret;
  461. }
  462. ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node);
  463. if (ret != SUCCESS) {
  464. GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str());
  465. return ret;
  466. }
  467. GELOGI("Pre node is %s and %dth output rw type is %s. Isolate and remove Identity node %s.",
  468. pre_node->GetName().c_str(), pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(),
  469. node->GetName().c_str());
  470. return SUCCESS;
  471. }
  472. Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &peer_in_data_anchor,
  473. const OutDataAnchorPtr &pre_out_data_anchor, NodePtr &pre_node) {
  474. // 1.check peer in node RW type.
  475. GE_CHECK_NOTNULL(peer_in_data_anchor);
  476. auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
  477. GE_CHECK_NOTNULL(peer_in_data_node);
  478. auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx());
  479. auto ret = out_data_anchor->Unlink(peer_in_data_anchor);
  480. auto old_identity = out_data_anchor->GetOwnerNode();
  481. if (ret != SUCCESS) {
  482. GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(),
  483. peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  484. return ret;
  485. }
  486. if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) {
  487. auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx());
  488. GE_CHECK_NOTNULL(new_identity);
  489. if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS
  490. || GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) {
  491. GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s",
  492. pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
  493. peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  494. return INTERNAL_ERROR;
  495. }
  496. // 2. copy in-control-edge from dst to Identity
  497. if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) {
  498. GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(),
  499. new_identity->GetName().c_str());
  500. return INTERNAL_ERROR;
  501. }
  502. GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(),
  503. InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
  504. peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  505. } else {
  506. // copy control edge to pre and peer node
  507. if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS
  508. || GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) {
  509. GELOGW("Fail to copy control edge from node %s.", old_identity->GetName().c_str());
  510. return FAILED;
  511. }
  512. // link identity pre node to next node directly
  513. if (GraphUtils::AddEdge(pre_out_data_anchor, peer_in_data_anchor) != SUCCESS) {
  514. GELOGW("Fail to link data edge from node %s to %s.", pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
  515. peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  516. return FAILED;
  517. }
  518. GELOGI("Node %s input rw type is %s, link data edge from Identity input node %s to out node %s directly.",
  519. peer_in_data_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str(),
  520. pre_node->GetName().c_str(), peer_in_data_node->GetName().c_str());
  521. }
  522. return SUCCESS;
  523. }
  524. Status SplitIdentity(const NodePtr &node) {
  525. GE_CHECK_NOTNULL(node);
  526. auto out_data_anchor = node->GetOutDataAnchor(kIdentityAnchorIndex);
  527. GE_CHECK_NOTNULL(out_data_anchor);
  528. if (out_data_anchor->GetPeerInDataNodesSize() <= 1) {
  529. return SUCCESS;
  530. }
  531. // get pre node and next node of identity
  532. GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex));
  533. auto pre_out_data_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor();
  534. GE_CHECK_NOTNULL(pre_out_data_anchor);
  535. auto pre_node = pre_out_data_anchor->GetOwnerNode();
  536. GE_CHECK_NOTNULL(pre_node);
  537. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  538. Status ret = SplitIdentityAlongAnchor(out_data_anchor, peer_in_data_anchor, pre_out_data_anchor, pre_node);
  539. if (ret != SUCCESS) {
  540. GELOGE(ret, "Split identity node along anchor failed.");
  541. return ret;
  542. }
  543. }
  544. // 2.isolate Identity node with no data output
  545. if (node->GetOutDataNodesSize() == 0) {
  546. Status ret = GraphUtils::IsolateNode(node, {});
  547. if (ret != SUCCESS) {
  548. GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str());
  549. return FAILED;
  550. }
  551. ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node);
  552. if (ret != SUCCESS) {
  553. GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str());
  554. return FAILED;
  555. }
  556. GELOGI("IsolateAndDelete identity node %s.", node->GetName().c_str());
  557. }
  558. return SUCCESS;
  559. }
  560. Status InsertIdentityAsNeeded(const NodePtr &node) {
  561. auto op_desc = node->GetOpDesc();
  562. GE_CHECK_NOTNULL(op_desc);
  563. if (node->GetOutDataNodesSize() == 0) {
  564. return SUCCESS;
  565. }
  566. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  567. GE_CHECK_NOTNULL(out_data_anchor);
  568. auto output_rw_type = GetOutputRWTypeByIndex(*node, out_data_anchor->GetIdx());
  569. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  570. GE_CHECK_NOTNULL(peer_in_data_anchor);
  571. auto peer_in_node = peer_in_data_anchor->GetOwnerNode();
  572. GE_CHECK_NOTNULL(peer_in_node);
  573. auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_data_anchor->GetIdx());
  574. GELOGD("Node %s output rw type is %s, Node %s input rw type is %s", node->GetName().c_str(),
  575. OutputRWTypeToSerialString(output_rw_type).c_str(), peer_in_node->GetName().c_str(),
  576. InputRWTypeToSerialString(input_rw_type).c_str());
  577. auto conflict_result = GetConflictResultBetweenNode(output_rw_type, input_rw_type);
  578. switch (conflict_result) {
  579. case DO_NOTHING:
  580. case WRONG_GRAPH:
  581. GELOGD("No need insert Identity.");
  582. continue;
  583. case INSERT_IDENTITY:
  584. auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx());
  585. if (identity_node == nullptr) {
  586. GELOGE(FAILED, "Create identity node failed.");
  587. return FAILED;
  588. }
  589. auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node);
  590. if (ret != GRAPH_SUCCESS) {
  591. GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(),
  592. peer_in_node->GetName().c_str());
  593. return INTERNAL_ERROR;
  594. }
  595. GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(),
  596. peer_in_node->GetName().c_str());
  597. continue;
  598. }
  599. }
  600. }
  601. return SUCCESS;
  602. }
  603. Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
  604. for (const auto &node : compute_graph->GetDirectNode()) {
  605. // op_desc of node should not be null
  606. const auto &op_desc = node->GetOpDesc();
  607. bool mutable_input_flag = false;
  608. if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) {
  609. GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str());
  610. continue;
  611. }
  612. std::set<OutDataAnchorPtr> pre_out_anchor_set;
  613. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  614. auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
  615. GE_CHECK_NOTNULL(pre_out_anchor);
  616. if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) {
  617. pre_out_anchor_set.emplace(pre_out_anchor);
  618. continue;
  619. }
  620. // need insert identity
  621. auto pre_node = pre_out_anchor->GetOwnerNode();
  622. auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
  623. GE_CHECK_NOTNULL(identity_node);
  624. auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
  625. GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
  626. GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
  627. pre_node->GetName().c_str(), node->GetName().c_str());
  628. }
  629. }
  630. return SUCCESS;
  631. }
  632. } // namespace
  633. namespace ge {
  634. Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_conflict) {
  635. node_rwtype_map_.clear();
  636. auto sub_graph_vec = compute_graph->GetAllSubgraphs();
  637. if (sub_graph_vec.empty()) {
  638. GELOGD("No sub graph here. Ignore memory conflict handle.");
  639. return SUCCESS;
  640. }
  641. // 1.loop all subgraph, mark rw type from inside to outside
  642. Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec);
  643. if (ret != SUCCESS) {
  644. GELOGE(ret, "Fail to mark rw type for subgraph.");
  645. return ret;
  646. }
  647. has_conflict = false;
  648. for (const auto &node : compute_graph->GetAllNodes()) {
  649. auto op_desc = node->GetOpDesc();
  650. GE_CHECK_NOTNULL(op_desc);
  651. if (node->GetOutDataNodesSize() == 0) {
  652. return SUCCESS;
  653. }
  654. if (node->GetType() == WHILE) {
  655. return SUCCESS;
  656. }
  657. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  658. GE_CHECK_NOTNULL(out_data_anchor);
  659. auto output_rw_type = GetOutputRWTypeByIndex(*node, out_data_anchor->GetIdx());
  660. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  661. GE_CHECK_NOTNULL(peer_in_data_anchor);
  662. auto peer_in_node = peer_in_data_anchor->GetOwnerNode();
  663. GE_CHECK_NOTNULL(peer_in_node);
  664. if (peer_in_node->GetType() == WHILE) {
  665. return SUCCESS;
  666. }
  667. auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_data_anchor->GetIdx());
  668. auto conflict_result = GetConflictResultBetweenNode(output_rw_type, input_rw_type);
  669. switch (conflict_result) {
  670. case DO_NOTHING:
  671. GELOGD("No rw conflict.");
  672. continue;
  673. case WRONG_GRAPH:
  674. has_conflict = true;
  675. GELOGI("Node %s output rw type is %s, next node %s input_rw_type is %s.It is wrong graph.",
  676. node->GetName().c_str(), OutputRWTypeToSerialString(output_rw_type).c_str(),
  677. peer_in_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str());
  678. return SUCCESS;
  679. case INSERT_IDENTITY:
  680. GELOGD("There is rw conflict. It will handle later.");
  681. continue;
  682. }
  683. }
  684. }
  685. }
  686. return SUCCESS;
  687. }
  688. Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) {
  689. GE_DUMP(compute_graph, "BeforeHandleMemConflict");
  690. node_rwtype_map_.clear();
  691. auto sub_graph_vec = compute_graph->GetAllSubgraphs();
  692. if (sub_graph_vec.empty()) {
  693. // only root graph, to handle allreduce servral input from one output anchor
  694. return HandleAllreduceDuplicateInput(compute_graph);
  695. }
  696. // 1.loop all subgraph, mark rw type from inside to outside
  697. Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec);
  698. if (ret != SUCCESS) {
  699. GELOGE(ret, "Fail to mark rw type for subgraph.");
  700. return ret;
  701. }
  702. // 2.loop all node, including node in subgraph and handle memory rw conflict
  703. for (auto &node : compute_graph->GetAllNodes()) {
  704. // ignore while subgraph node
  705. const auto parent_node = node->GetOwnerComputeGraph()->GetParentNode();
  706. if ((parent_node != nullptr) && (kWhileOpTypes.count(parent_node->GetType()) > 0)) {
  707. continue;
  708. }
  709. // ignore data / netoutput of subgraph
  710. if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
  711. continue;
  712. }
  713. if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
  714. continue;
  715. }
  716. bool identity_reserved = false;
  717. AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CANNOT_BE_DELETED, identity_reserved);
  718. if (identity_reserved) {
  719. GELOGD("Identity [%s] need to be reserved", node->GetName().c_str());
  720. continue;
  721. }
  722. if (node->GetType() == IDENTITY || node->GetType() == READVARIABLEOP) {
  723. // split identity
  724. ret = SplitIdentity(node);
  725. if (ret != SUCCESS) {
  726. GELOGE(ret, "Fail to split identity node %s.", node->GetName().c_str());
  727. return ret;
  728. }
  729. // remove no use identity
  730. ret = RemoveNoUseIdentity(node);
  731. if (ret != SUCCESS) {
  732. GELOGE(ret, "Fail to remove useless identity node %s.", node->GetName().c_str());
  733. return ret;
  734. }
  735. }
  736. // insert Identity
  737. ret = InsertIdentityAsNeeded(node);
  738. if (ret != SUCCESS) {
  739. GELOGE(ret, "Fail to insert Identity node.");
  740. return ret;
  741. }
  742. }
  743. GE_DUMP(compute_graph, "AfterHandleMemConflict");
  744. return SUCCESS;
  745. }
  746. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：