You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parallel_group_pass.cc 14 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/parallel_group_pass.h"
  17. #include "framework/common/debug/ge_log.h"
  18. #include "common/ge/ge_util.h"
  19. #include "framework/common/ge_inner_error_codes.h"
  20. #include "graph/debug/ge_attr_define.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "graph/utils/node_utils.h"
  23. namespace ge {
  namespace {
  // Upper bound on the subgraph nesting depth that ProcessGraphGroupNodes will descend into;
  // reaching it makes the pass fail.
  constexpr int32_t kMaxRecursionDepth = 10;
  // ATTR_NAME_STREAM_SWITCH_TYPE value that IsWhileStreamSwitch treats as a loop (while) StreamSwitch.
  constexpr int64_t kLoopType = 1;
  }  // namespace
  28. Status ParallelGroupPass::Run(ComputeGraphPtr graph) {
  29. GELOGD("ParallelGroupPass running");
  30. if (graph == nullptr) {
  31. GELOGE(PARAM_INVALID, "[Check][Graph]Input param graph is null, skip ParallelGroupPass.");
  32. REPORT_INNER_ERROR("E19999", "Input param graph is null, skip ParallelGroupPass.");
  33. return PARAM_INVALID;
  34. }
  35. if (graph->GetParentGraph() != nullptr) {
  36. GELOGD("Current graph %s is a subgraph, this pass only support root graph.",
  37. graph->GetName().c_str());
  38. return SUCCESS;
  39. }
  40. if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
  41. GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
  42. REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
  43. graph->GetName().c_str());
  44. return FAILED;
  45. }
  46. std::unordered_set<std::string> parallel_groups;
  47. int depth = 0;
  48. if (ProcessGraphGroupNodes(graph, depth, parallel_groups) != SUCCESS) {
  49. GELOGE(INTERNAL_ERROR, "[Process][Graph]Process group nodes of graph %s failed.", graph->GetName().c_str());
  50. return INTERNAL_ERROR;
  51. }
  52. if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
  53. GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
  54. REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
  55. graph->GetName().c_str());
  56. return FAILED;
  57. }
  58. return SUCCESS;
  59. }
  60. Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth,
  61. std::unordered_set<std::string> &parallel_groups) {
  62. if (depth >= kMaxRecursionDepth) {
  63. GELOGE(FAILED, "[Process][SubGraph]There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth);
  64. REPORT_INNER_ERROR("E19999", "There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth);
  65. return FAILED;
  66. }
  67. std::map<std::string, vector<NodePtr>> group_nodes;
  68. auto candidates = graph->GetDirectNode();
  69. auto root_graph = GraphUtils::FindRootGraph(graph);
  70. for (const auto &node : candidates) {
  71. OpDescPtr op_desc = node->GetOpDesc();
  72. if (op_desc == nullptr) {
  73. continue;
  74. }
  75. std::string group_name;
  76. if (AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) {
  77. group_nodes[group_name].push_back(node);
  78. parallel_groups.insert(group_name);
  79. GELOGD("Find group node:%s, group_name:%s", node->GetName().c_str(), group_name.c_str());
  80. }
  81. const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
  82. GE_CHECK_NOTNULL(root_graph);
  83. for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) {
  84. const auto &sub_graph = root_graph->GetSubgraph(*name_iter);
  85. GE_CHECK_NOTNULL(sub_graph);
  86. // if the pass add control edge for known and unknown graph, then the known graph will become unknown graph
  87. // the order between known and unknown graph is guaranteed by dynamic shape executor
  88. // so the parallel group pass do nothing for unknown graph
  89. if (sub_graph->GetGraphUnknownFlag()) {
  90. continue;
  91. }
  92. std::unordered_set<std::string> sub_parallel_groups;
  93. auto ret = ProcessGraphGroupNodes(sub_graph, depth + 1, sub_parallel_groups);
  94. if (ret != SUCCESS) {
  95. GELOGE(FAILED, "[Process][SubGraph]Process sub graph %s failed.", sub_graph->GetName().c_str());
  96. return FAILED;
  97. }
  98. for (const auto &sub_parallel_group : sub_parallel_groups) {
  99. parallel_groups.insert(sub_parallel_group);
  100. group_nodes[sub_parallel_group].emplace_back(node);
  101. }
  102. }
  103. }
  104. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge;
  105. if (ProcessGroupNodeInSwitch(graph, node_2_switch_merge) != SUCCESS) {
  106. GELOGE(FAILED, "[Process][Node]Process group node in switch failed, graph:%s.", graph->GetName().c_str());
  107. return FAILED;
  108. }
  109. for (const auto &itr : group_nodes) {
  110. const auto &nodes = itr.second;
  111. if (nodes.empty()) {
  112. continue;
  113. }
  114. NodePtr pre_node = nodes[0];
  115. NodePtr cur_node = nullptr;
  116. for (std::size_t i = 1; i < nodes.size(); i++) {
  117. cur_node = nodes[i];
  118. GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
  119. if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) {
  120. GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.",
  121. pre_node->GetName().c_str(), cur_node->GetName().c_str());
  122. return FAILED;
  123. }
  124. pre_node = cur_node;
  125. }
  126. }
  127. return SUCCESS;
  128. }
  129. Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
  130. if (pre_node == cur_node) {
  131. GELOGD("Pre_node and cur_node are same, no need add anchor");
  132. return SUCCESS;
  133. }
  134. auto in_nodes = cur_node->GetInAllNodes();
  135. for (const auto &node : in_nodes) {
  136. if (pre_node == node) {
  137. GELOGD("Node:%s and %s already linked", pre_node->GetName().c_str(),
  138. cur_node->GetName().c_str());
  139. return SUCCESS;
  140. }
  141. }
  142. GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
  143. return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor());
  144. }
  145. Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
  146. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  147. std::string type;
  148. auto direct_nodes = graph->GetDirectNode();
  149. for (const auto &node : direct_nodes) {
  150. type = node->GetType();
  151. if (type != STREAMSWITCH) {
  152. continue;
  153. }
  154. if (IsBigSmallLoopStreamSwitch(node->GetOpDesc()) ||
  155. IsWhileStreamSwitch(node->GetOpDesc())) {
  156. continue;
  157. }
  158. std::vector<NodePtr> merge_nodes;
  159. std::set<NodePtr> group_nodes;
  160. std::set<std::string> stream_labels;
  161. FindGroupNodeAndMerge(node, group_nodes, merge_nodes, stream_labels);
  162. if (merge_nodes.empty() || (!group_nodes.empty() && stream_labels.size() > 1)) {
  163. GELOGE(FAILED, "[Process][Node]Cannot find merge node or exist switch nestification, switch node:%s,"
  164. "merge_vec size:%zu, stream_labels size:%zu, graph:%s.", node->GetName().c_str(),
  165. merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
  166. REPORT_INNER_ERROR("E19999", "Cannot find merge node or exist switch nest, switch node:%s,"
  167. "merge_vec size: %zu, stream_labels size: %zu, graph:%s.", node->GetName().c_str(),
  168. merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
  169. return FAILED;
  170. }
  171. std::sort(merge_nodes.begin(), merge_nodes.end(),
  172. [] (NodePtr a, NodePtr b) -> bool {
  173. return (a->GetOpDesc()->GetId() < b->GetOpDesc()->GetId());
  174. });
  175. NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0);
  176. GE_CHECK_NOTNULL(cast_node);
  177. if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) {
  178. GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str());
  179. REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.",
  180. graph->GetName().c_str());
  181. return FAILED;
  182. }
  183. }
  184. return SUCCESS;
  185. }
  186. void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes,
  187. std::vector<NodePtr> &merge_nodes, std::set<std::string> &stream_labels) {
  188. std::string type;
  189. std::deque<NodePtr> candidates;
  190. std::set<NodePtr> visited;
  191. candidates.push_back(stream_switch_node);
  192. while (!candidates.empty()) {
  193. NodePtr tmp_node = candidates.front();
  194. candidates.pop_front();
  195. for (const auto &out_node : tmp_node->GetOutAllNodes()) {
  196. type = out_node->GetType();
  197. if (type == STREAMMERGE) {
  198. merge_nodes.emplace_back(out_node);
  199. continue;
  200. }
  201. const auto &op = out_node->GetOpDesc();
  202. if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
  203. group_nodes.emplace(out_node);
  204. }
  205. if (visited.count(out_node) > 0) {
  206. continue;
  207. }
  208. candidates.push_back(out_node);
  209. visited.insert(out_node);
  210. std::string stream_label;
  211. if (ge::AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
  212. stream_labels.insert(stream_label);
  213. }
  214. }
  215. }
  216. }
  217. Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes,
  218. const std::vector<NodePtr> &merge_nodes, const NodePtr &cast_node, const NodePtr &switch_node,
  219. std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  220. for (const auto &group_node : group_nodes) {
  221. auto itr = node_2_switch_merge.find(group_node);
  222. if (itr != node_2_switch_merge.end()) {
  223. auto &tmp = itr->second;
  224. auto &switch_set = tmp.first;
  225. const auto &merge_node = tmp.second;
  226. GELOGD("Find group node: %s in switch %s and merge %s.",
  227. group_node->GetName().c_str(), switch_node->GetName().c_str(), merge_node->GetName().c_str());
  228. if (merge_node != merge_nodes.back()) {
  229. GELOGE(FAILED, "[Mapping][Node]Has two different merge nodes: %s and %s, graph's structure is invalid",
  230. merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
  231. REPORT_INNER_ERROR("E19999", "Has two different merge nodes: %s and %s,"
  232. "graph's structure is invalid",
  233. merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
  234. return FAILED;
  235. }
  236. switch_set.insert(cast_node);
  237. } else {
  238. node_2_switch_merge.emplace(group_node,
  239. std::make_pair(std::set<NodePtr>{cast_node}, merge_nodes.back()));
  240. }
  241. }
  242. return SUCCESS;
  243. }
  244. Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node,
  245. const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  246. auto pre_itr = node_2_switch_merge.find(pre_node);
  247. auto cur_itr = node_2_switch_merge.find(cur_node);
  248. if (pre_itr != node_2_switch_merge.end()) {
  249. if (cur_itr != node_2_switch_merge.end()) {
  250. const auto &pre_set = pre_itr->second.first;
  251. const auto &cur_set = cur_itr->second.first;
  252. if (!HasSameSwitch(pre_set, cur_set)) {
  253. pre_node = pre_itr->second.second;
  254. for (const auto &switch_node : cur_itr->second.first) {
  255. if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
  256. GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  257. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  258. REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  259. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  260. return FAILED;
  261. }
  262. }
  263. }
  264. return SUCCESS;
  265. } else {
  266. pre_node = pre_itr->second.second;
  267. return AddCtrlEdge(pre_node, cur_node);
  268. }
  269. } else {
  270. if (cur_itr != node_2_switch_merge.end()) {
  271. for (const auto &switch_node : cur_itr->second.first) {
  272. int64_t pre_id = pre_node->GetOpDesc()->GetId();
  273. int64_t switch_id = switch_node->GetOpDesc()->GetId();
  274. // avoid ring
  275. if (pre_id > switch_id) {
  276. auto merge_node = cur_itr->second.second;
  277. if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) {
  278. GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  279. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  280. REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  281. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  282. return FAILED;
  283. }
  284. } else {
  285. if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
  286. GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  287. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  288. REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
  289. pre_node->GetName().c_str(), switch_node->GetName().c_str());
  290. return FAILED;
  291. }
  292. }
  293. }
  294. } else {
  295. return AddCtrlEdge(pre_node, cur_node);
  296. }
  297. }
  298. return SUCCESS;
  299. }
  300. bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) {
  301. for (const auto &node1 : switch_set1) {
  302. auto itr = switch_set2.find(node1);
  303. if (itr != switch_set2.end()) {
  304. return true;
  305. }
  306. }
  307. return false;
  308. }
  309. bool ParallelGroupPass::IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc) {
  310. return !AttrUtils::HasAttr(switch_op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG);
  311. }
  312. bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
  313. int64_t stream_switch_type = -1;
  314. return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
  315. stream_switch_type == kLoopType);
  316. }
  317. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：