|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "graph/passes/parallel_group_pass.h"
-
- #include "framework/common/debug/ge_log.h"
- #include "common/ge/ge_util.h"
- #include "framework/common/ge_inner_error_codes.h"
- #include "graph/debug/ge_attr_define.h"
- #include "graph/utils/graph_utils.h"
- #include "graph/utils/node_utils.h"
-
- namespace ge {
namespace {
// Recursion limit checked by ProcessGraphGroupNodes before it handles a graph at a given depth.
constexpr int32_t kMaxRecursionDepth = 10;
// Value of ATTR_NAME_STREAM_SWITCH_TYPE that IsWhileStreamSwitch compares against.
constexpr int64_t kLoopType = 1;
}  // namespace
-
- Status ParallelGroupPass::Run(ComputeGraphPtr graph) {
- GELOGD("ParallelGroupPass running");
- if (graph == nullptr) {
- GELOGE(PARAM_INVALID, "[Check][Graph]Input param graph is null, skip ParallelGroupPass.");
- REPORT_INNER_ERROR("E19999", "Input param graph is null, skip ParallelGroupPass.");
- return PARAM_INVALID;
- }
-
- if (graph->GetParentGraph() != nullptr) {
- GELOGD("Current graph %s is a subgraph, this pass only support root graph.",
- graph->GetName().c_str());
- return SUCCESS;
- }
-
- if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
- GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
- REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
- graph->GetName().c_str());
- return FAILED;
- }
-
- std::unordered_set<std::string> parallel_groups;
- int depth = 0;
- if (ProcessGraphGroupNodes(graph, depth, parallel_groups) != SUCCESS) {
- GELOGE(INTERNAL_ERROR, "[Process][Graph]Process group nodes of graph %s failed.", graph->GetName().c_str());
- return INTERNAL_ERROR;
- }
-
- if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
- GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
- REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
- graph->GetName().c_str());
- return FAILED;
- }
-
- return SUCCESS;
- }
-
- Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth,
- std::unordered_set<std::string> ¶llel_groups) {
- if (depth >= kMaxRecursionDepth) {
- GELOGE(FAILED, "[Process][SubGraph]There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth);
- REPORT_INNER_ERROR("E19999", "There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth);
- return FAILED;
- }
- std::map<std::string, vector<NodePtr>> group_nodes;
- auto candidates = graph->GetDirectNode();
- auto root_graph = GraphUtils::FindRootGraph(graph);
- for (const auto &node : candidates) {
- OpDescPtr op_desc = node->GetOpDesc();
- if (op_desc == nullptr) {
- continue;
- }
- std::string group_name;
- if (AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) {
- group_nodes[group_name].push_back(node);
- parallel_groups.insert(group_name);
- GELOGD("Find group node:%s, group_name:%s", node->GetName().c_str(), group_name.c_str());
- }
-
- const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
- GE_CHECK_NOTNULL(root_graph);
- for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) {
- const auto &sub_graph = root_graph->GetSubgraph(*name_iter);
- GE_CHECK_NOTNULL(sub_graph);
- // if the pass add control edge for known and unknown graph, then the known graph will become unknown graph
- // the order between known and unknown graph is guaranteed by dynamic shape executor
- // so the parallel group pass do nothing for unknown graph
- if (sub_graph->GetGraphUnknownFlag()) {
- continue;
- }
- std::unordered_set<std::string> sub_parallel_groups;
- auto ret = ProcessGraphGroupNodes(sub_graph, depth + 1, sub_parallel_groups);
- if (ret != SUCCESS) {
- GELOGE(FAILED, "[Process][SubGraph]Process sub graph %s failed.", sub_graph->GetName().c_str());
- return FAILED;
- }
- for (const auto &sub_parallel_group : sub_parallel_groups) {
- parallel_groups.insert(sub_parallel_group);
- group_nodes[sub_parallel_group].emplace_back(node);
- }
- }
- }
-
- std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge;
- if (ProcessGroupNodeInSwitch(graph, node_2_switch_merge) != SUCCESS) {
- GELOGE(FAILED, "[Process][Node]Process group node in switch failed, graph:%s.", graph->GetName().c_str());
- return FAILED;
- }
-
- for (const auto &itr : group_nodes) {
- const auto &nodes = itr.second;
- if (nodes.empty()) {
- continue;
- }
- NodePtr pre_node = nodes[0];
- NodePtr cur_node = nullptr;
- for (std::size_t i = 1; i < nodes.size(); i++) {
- cur_node = nodes[i];
- GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
- if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) {
- GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.",
- pre_node->GetName().c_str(), cur_node->GetName().c_str());
- return FAILED;
- }
- pre_node = cur_node;
- }
- }
-
- return SUCCESS;
- }
-
- Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
- if (pre_node == cur_node) {
- GELOGD("Pre_node and cur_node are same, no need add anchor");
- return SUCCESS;
- }
- auto in_nodes = cur_node->GetInAllNodes();
- for (const auto &node : in_nodes) {
- if (pre_node == node) {
- GELOGD("Node:%s and %s already linked", pre_node->GetName().c_str(),
- cur_node->GetName().c_str());
- return SUCCESS;
- }
- }
- GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
- return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor());
- }
-
- Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
- std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
-
- std::string type;
- auto direct_nodes = graph->GetDirectNode();
- for (const auto &node : direct_nodes) {
- type = node->GetType();
- if (type != STREAMSWITCH) {
- continue;
- }
-
- if (IsBigSmallLoopStreamSwitch(node->GetOpDesc()) ||
- IsWhileStreamSwitch(node->GetOpDesc())) {
- continue;
- }
-
- std::vector<NodePtr> merge_nodes;
- std::set<NodePtr> group_nodes;
- std::set<std::string> stream_labels;
-
- FindGroupNodeAndMerge(node, group_nodes, merge_nodes, stream_labels);
-
- if (merge_nodes.empty() || (!group_nodes.empty() && stream_labels.size() > 1)) {
- GELOGE(FAILED, "[Process][Node]Cannot find merge node or exist switch nestification, switch node:%s,"
- "merge_vec size:%zu, stream_labels size:%zu, graph:%s.", node->GetName().c_str(),
- merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
- REPORT_INNER_ERROR("E19999", "Cannot find merge node or exist switch nest, switch node:%s,"
- "merge_vec size: %zu, stream_labels size: %zu, graph:%s.", node->GetName().c_str(),
- merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
- return FAILED;
- }
-
- std::sort(merge_nodes.begin(), merge_nodes.end(),
- [] (NodePtr a, NodePtr b) -> bool {
- return (a->GetOpDesc()->GetId() < b->GetOpDesc()->GetId());
- });
-
- NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0);
- GE_CHECK_NOTNULL(cast_node);
- if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) {
- GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str());
- REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.",
- graph->GetName().c_str());
- return FAILED;
- }
- }
-
- return SUCCESS;
- }
-
- void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes,
- std::vector<NodePtr> &merge_nodes, std::set<std::string> &stream_labels) {
- std::string type;
- std::deque<NodePtr> candidates;
- std::set<NodePtr> visited;
-
- candidates.push_back(stream_switch_node);
- while (!candidates.empty()) {
- NodePtr tmp_node = candidates.front();
- candidates.pop_front();
- for (const auto &out_node : tmp_node->GetOutAllNodes()) {
- type = out_node->GetType();
- if (type == STREAMMERGE) {
- merge_nodes.emplace_back(out_node);
- continue;
- }
- const auto &op = out_node->GetOpDesc();
- if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
- group_nodes.emplace(out_node);
- }
- if (visited.count(out_node) > 0) {
- continue;
- }
- candidates.push_back(out_node);
- visited.insert(out_node);
- std::string stream_label;
- if (ge::AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
- stream_labels.insert(stream_label);
- }
- }
- }
- }
-
- Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes,
- const std::vector<NodePtr> &merge_nodes, const NodePtr &cast_node, const NodePtr &switch_node,
- std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
- for (const auto &group_node : group_nodes) {
- auto itr = node_2_switch_merge.find(group_node);
- if (itr != node_2_switch_merge.end()) {
- auto &tmp = itr->second;
- auto &switch_set = tmp.first;
- const auto &merge_node = tmp.second;
- GELOGD("Find group node: %s in switch %s and merge %s.",
- group_node->GetName().c_str(), switch_node->GetName().c_str(), merge_node->GetName().c_str());
- if (merge_node != merge_nodes.back()) {
- GELOGE(FAILED, "[Mapping][Node]Has two different merge nodes: %s and %s, graph's structure is invalid",
- merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
- REPORT_INNER_ERROR("E19999", "Has two different merge nodes: %s and %s,"
- "graph's structure is invalid",
- merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
- return FAILED;
- }
- switch_set.insert(cast_node);
- } else {
- node_2_switch_merge.emplace(group_node,
- std::make_pair(std::set<NodePtr>{cast_node}, merge_nodes.back()));
- }
- }
- return SUCCESS;
- }
-
- Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node,
- const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
- auto pre_itr = node_2_switch_merge.find(pre_node);
- auto cur_itr = node_2_switch_merge.find(cur_node);
- if (pre_itr != node_2_switch_merge.end()) {
- if (cur_itr != node_2_switch_merge.end()) {
- const auto &pre_set = pre_itr->second.first;
- const auto &cur_set = cur_itr->second.first;
- if (!HasSameSwitch(pre_set, cur_set)) {
- pre_node = pre_itr->second.second;
- for (const auto &switch_node : cur_itr->second.first) {
- if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
- GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
- pre_node->GetName().c_str(), switch_node->GetName().c_str());
- REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
- pre_node->GetName().c_str(), switch_node->GetName().c_str());
- return FAILED;
- }
- }
- }
- return SUCCESS;
- } else {
- pre_node = pre_itr->second.second;
- return AddCtrlEdge(pre_node, cur_node);
- }
- } else {
- if (cur_itr != node_2_switch_merge.end()) {
- for (const auto &switch_node : cur_itr->second.first) {
- int64_t pre_id = pre_node->GetOpDesc()->GetId();
- int64_t switch_id = switch_node->GetOpDesc()->GetId();
- // avoid ring
- if (pre_id > switch_id) {
- auto merge_node = cur_itr->second.second;
- if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) {
- GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
- pre_node->GetName().c_str(), switch_node->GetName().c_str());
- REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
- pre_node->GetName().c_str(), switch_node->GetName().c_str());
- return FAILED;
- }
- } else {
- if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
- GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
- pre_node->GetName().c_str(), switch_node->GetName().c_str());
- REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
- pre_node->GetName().c_str(), switch_node->GetName().c_str());
- return FAILED;
- }
- }
- }
- } else {
- return AddCtrlEdge(pre_node, cur_node);
- }
- }
- return SUCCESS;
- }
-
- bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) {
- for (const auto &node1 : switch_set1) {
- auto itr = switch_set2.find(node1);
- if (itr != switch_set2.end()) {
- return true;
- }
- }
- return false;
- }
-
- bool ParallelGroupPass::IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc) {
- return !AttrUtils::HasAttr(switch_op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG);
- }
-
- bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
- int64_t stream_switch_type = -1;
- return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
- stream_switch_type == kLoopType);
- }
- } // namespace ge
|