You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parallel_group_pass.cc 15 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/parallel_group_pass.h"
  17. #include "framework/common/debug/ge_log.h"
  18. #include "common/ge/ge_util.h"
  19. #include "framework/common/ge_inner_error_codes.h"
  20. #include "graph/debug/ge_attr_define.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "graph/utils/node_utils.h"
  23. namespace ge {
  namespace {
  // Recursion limit checked by ProcessGraphGroupNodes: processing is rejected once depth reaches this value.
  constexpr int32_t kMaxRecursionDepth = 10;
  // Value of ATTR_NAME_STREAM_SWITCH_TYPE that IsWhileStreamSwitch compares against.
  constexpr int64_t kLoopType = 1;
  }  // namespace
  28. Status ParallelGroupPass::Run(ComputeGraphPtr graph) {
  29. GELOGD("ParallelGroupPass running");
  30. if (graph == nullptr) {
  31. GELOGE(PARAM_INVALID, "[Check][Graph]Input param graph is null, skip ParallelGroupPass.");
  32. REPORT_INNER_ERROR("E19999", "Input param graph is null, skip ParallelGroupPass.");
  33. return PARAM_INVALID;
  34. }
  35. if (graph->GetParentGraph() != nullptr) {
  36. GELOGD("Current graph %s is a subgraph, this pass only support root graph.",
  37. graph->GetName().c_str());
  38. return SUCCESS;
  39. }
  40. if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
  41. GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
  42. REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
  43. graph->GetName().c_str());
  44. return FAILED;
  45. }
  46. std::unordered_set<std::string> parallel_groups;
  47. int depth = 0;
  48. if (ProcessGraphGroupNodes(graph, depth, parallel_groups) != SUCCESS) {
  49. GELOGE(INTERNAL_ERROR, "[Process][Graph]Process group nodes of graph %s failed.", graph->GetName().c_str());
  50. return INTERNAL_ERROR;
  51. }
  52. if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
  53. GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
  54. REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
  55. graph->GetName().c_str());
  56. return FAILED;
  57. }
  58. return SUCCESS;
  59. }
  60. Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth,
  61. std::unordered_set<std::string> &parallel_groups) {
  62. if (depth >= kMaxRecursionDepth) {
  63. GELOGE(FAILED, "[Process][SubGraph]There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth);
  64. REPORT_INNER_ERROR("E19999", "There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth);
  65. return FAILED;
  66. }
  67. std::map<std::string, vector<NodePtr>> group_nodes;
  68. auto candidates = graph->GetDirectNode();
  69. auto root_graph = GraphUtils::FindRootGraph(graph);
  70. for (const auto &node : candidates) {
  71. OpDescPtr op_desc = node->GetOpDesc();
  72. if (op_desc == nullptr) {
  73. continue;
  74. }
  75. std::string group_name;
  76. if (AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) {
  77. group_nodes[group_name].push_back(node);
  78. parallel_groups.insert(group_name);
  79. GELOGD("Find group node:%s, group_name:%s", node->GetName().c_str(), group_name.c_str());
  80. }
  81. const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
  82. GE_CHECK_NOTNULL(root_graph);
  83. for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) {
  84. const auto &sub_graph = root_graph->GetSubgraph(*name_iter);
  85. GE_CHECK_NOTNULL(sub_graph);
  86. // if the pass add control edge for known and unknown graph, then the known graph will become unknown graph
  87. // the order between known and unknown graph is guaranteed by dynamic shape executor
  88. // so the parallel group pass do nothing for unknown graph
  89. if (sub_graph->GetGraphUnknownFlag()) {
  90. continue;
  91. }
  92. std::unordered_set<std::string> sub_parallel_groups;
  93. auto ret = ProcessGraphGroupNodes(sub_graph, depth + 1, sub_parallel_groups);
  94. if (ret != SUCCESS) {
  95. GELOGE(FAILED, "[Process][SubGraph]Process sub graph %s failed.", sub_graph->GetName().c_str());
  96. return FAILED;
  97. }
  98. for (const auto &sub_parallel_group : sub_parallel_groups) {
  99. parallel_groups.insert(sub_parallel_group);
  100. group_nodes[sub_parallel_group].emplace_back(node);
  101. }
  102. }
  103. }
  104. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge;
  105. if (ProcessGroupNodeInSwitch(graph, node_2_switch_merge) != SUCCESS) {
  106. GELOGE(FAILED, "[Process][Node]Process group node in switch failed, graph:%s.", graph->GetName().c_str());
  107. return FAILED;
  108. }
  109. for (const auto &itr : group_nodes) {
  110. const auto &nodes = itr.second;
  111. if (nodes.empty()) {
  112. continue;
  113. }
  114. NodePtr pre_node = nodes[0];
  115. NodePtr cur_node = nullptr;
  116. for (std::size_t i = 1; i < nodes.size(); i++) {
  117. cur_node = nodes[i];
  118. GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(),
  119. cur_node->GetName().c_str());
  120. if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) {
  121. GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.",
  122. pre_node->GetName().c_str(), cur_node->GetName().c_str());
  123. return FAILED;
  124. }
  125. pre_node = cur_node;
  126. }
  127. }
  128. return SUCCESS;
  129. }
  130. Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
  131. if (pre_node == cur_node) {
  132. GELOGD("Pre_node and cur_node are same, no need add anchor");
  133. return SUCCESS;
  134. }
  135. auto in_nodes = cur_node->GetInAllNodes();
  136. for (const auto &node : in_nodes) {
  137. if (pre_node == node) {
  138. GELOGD("Node:%s and %s already linked", pre_node->GetName().c_str(),
  139. cur_node->GetName().c_str());
  140. return SUCCESS;
  141. }
  142. }
  143. GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(),
  144. cur_node->GetName().c_str());
  145. return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(),
  146. cur_node->GetInControlAnchor());
  147. }
  148. Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
  149. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  150. std::string type;
  151. auto direct_nodes = graph->GetDirectNode();
  152. for (const auto &node : direct_nodes) {
  153. type = node->GetType();
  154. if (type != STREAMSWITCH) {
  155. continue;
  156. }
  157. if (IsBigSmallLoopStreamSwitch(node->GetOpDesc()) ||
  158. IsWhileStreamSwitch(node->GetOpDesc())) {
  159. continue;
  160. }
  161. std::vector<NodePtr> merge_nodes;
  162. std::set<NodePtr> group_nodes;
  163. std::set<std::string> stream_labels;
  164. FindGroupNodeAndMerge(node, group_nodes, merge_nodes, stream_labels);
  165. if (merge_nodes.empty() || (!group_nodes.empty() && stream_labels.size() > 1)) {
  166. GELOGE(FAILED, "[Process][Node]Cannot find merge node or exist switch nestification, switch node:%s,"
  167. "merge_vec size:%zu, stream_labels size:%zu, graph:%s.", node->GetName().c_str(),
  168. merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
  169. REPORT_INNER_ERROR("E19999", "Cannot find merge node or exist switch nest, switch node:%s,"
  170. "merge_vec size: %zu, stream_labels size: %zu, graph:%s.", node->GetName().c_str(),
  171. merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
  172. return FAILED;
  173. }
  174. std::sort(merge_nodes.begin(), merge_nodes.end(),
  175. [] (NodePtr a, NodePtr b) -> bool {
  176. return (a->GetOpDesc()->GetId() < b->GetOpDesc()->GetId());
  177. });
  178. NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0);
  179. GE_CHECK_NOTNULL(cast_node);
  180. if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes,
  181. cast_node, node,
  182. node_2_switch_merge) != SUCCESS) {
  183. GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str());
  184. REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.",
  185. graph->GetName().c_str());
  186. return FAILED;
  187. }
  188. }
  189. return SUCCESS;
  190. }
  191. void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes,
  192. std::vector<NodePtr> &merge_nodes, std::set<std::string> &stream_labels) {
  193. std::string type;
  194. std::deque<NodePtr> candidates;
  195. std::set<NodePtr> visited;
  196. candidates.push_back(stream_switch_node);
  197. while (!candidates.empty()) {
  198. NodePtr tmp_node = candidates.front();
  199. candidates.pop_front();
  200. for (const auto &out_node : tmp_node->GetOutAllNodes()) {
  201. type = out_node->GetType();
  202. if (type == STREAMMERGE) {
  203. merge_nodes.emplace_back(out_node);
  204. continue;
  205. }
  206. const auto &op = out_node->GetOpDesc();
  207. if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
  208. group_nodes.emplace(out_node);
  209. }
  210. if (visited.count(out_node) > 0) {
  211. continue;
  212. }
  213. candidates.push_back(out_node);
  214. visited.insert(out_node);
  215. std::string stream_label;
  216. if (ge::AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
  217. stream_labels.insert(stream_label);
  218. }
  219. }
  220. }
  221. }
  222. Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes,
  223. const std::vector<NodePtr> &merge_nodes,
  224. const NodePtr &cast_node, const NodePtr &switch_node,
  225. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  226. for (const auto &group_node : group_nodes) {
  227. auto itr = node_2_switch_merge.find(group_node);
  228. if (itr != node_2_switch_merge.end()) {
  229. auto &tmp = itr->second;
  230. auto &switch_set = tmp.first;
  231. const auto &merge_node = tmp.second;
  232. GELOGD("Find group node: %s in switch %s and merge %s.",
  233. group_node->GetName().c_str(), switch_node->GetName().c_str(), merge_node->GetName().c_str());
  234. if (merge_node != merge_nodes.back()) {
  235. GELOGE(FAILED, "[Mapping][Node]Has two different merge nodes: %s and %s, graph's structure is invalid",
  236. merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
  237. REPORT_INNER_ERROR("E19999", "Has two different merge nodes: %s and %s,"
  238. "graph's structure is invalid",
  239. merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
  240. return FAILED;
  241. }
  242. switch_set.insert(cast_node);
  243. } else {
  244. node_2_switch_merge.emplace(group_node,
  245. std::make_pair(std::set<NodePtr>{cast_node}, merge_nodes.back()));
  246. }
  247. }
  248. return SUCCESS;
  249. }
  250. Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node,
  251. const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  252. auto pre_itr = node_2_switch_merge.find(pre_node);
  253. auto cur_itr = node_2_switch_merge.find(cur_node);
  254. if (pre_itr != node_2_switch_merge.end()) {
  255. if (cur_itr != node_2_switch_merge.end()) {
  256. const auto &pre_set = pre_itr->second.first;
  257. const auto &cur_set = cur_itr->second.first;
  258. if (!HasSameSwitch(pre_set, cur_set)) {
  259. pre_node = pre_itr->second.second;
  260. for (const auto &switch_node : cur_itr->second.first) {
  261. if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
  262. GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  263. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  264. REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  265. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  266. return FAILED;
  267. }
  268. }
  269. }
  270. return SUCCESS;
  271. } else {
  272. pre_node = pre_itr->second.second;
  273. return AddCtrlEdge(pre_node, cur_node);
  274. }
  275. } else {
  276. if (cur_itr != node_2_switch_merge.end()) {
  277. for (const auto &switch_node : cur_itr->second.first) {
  278. int64_t pre_id = pre_node->GetOpDesc()->GetId();
  279. int64_t switch_id = switch_node->GetOpDesc()->GetId();
  280. // avoid ring
  281. if (pre_id > switch_id) {
  282. auto merge_node = cur_itr->second.second;
  283. if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) {
  284. GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  285. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  286. REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  287. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  288. return FAILED;
  289. }
  290. } else {
  291. if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
  292. GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  293. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  294. REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  295. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  296. return FAILED;
  297. }
  298. }
  299. }
  300. } else {
  301. return AddCtrlEdge(pre_node, cur_node);
  302. }
  303. }
  304. return SUCCESS;
  305. }
  306. bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) {
  307. for (const auto &node1 : switch_set1) {
  308. auto itr = switch_set2.find(node1);
  309. if (itr != switch_set2.end()) {
  310. return true;
  311. }
  312. }
  313. return false;
  314. }
  315. bool ParallelGroupPass::IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc) {
  316. return !AttrUtils::HasAttr(switch_op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG);
  317. }
  318. bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
  319. int64_t stream_switch_type = -1;
  320. return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
  321. stream_switch_type == kLoopType);
  322. }
  323. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: