From cf101e0aa24d4b569a310a6bdc652a10a7106724 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=8D=8E?= Date: Tue, 9 Mar 2021 14:32:20 +0800 Subject: [PATCH] parallel group --- ge/CMakeLists.txt | 2 + ge/graph/build/logical_stream_allocator.cc | 43 +++ ge/graph/build/logical_stream_allocator.h | 7 + ge/graph/manager/graph_manager.cc | 7 + ge/graph/passes/next_iteration_pass.cc | 9 +- ge/graph/passes/parallel_group_pass.cc | 354 ++++++++++++++++++ ge/graph/passes/parallel_group_pass.h | 53 +++ .../passes/switch_to_stream_switch_pass.cc | 7 + tests/ut/ge/CMakeLists.txt | 3 + .../logical_stream_allocator_unittest.cc | 43 +++ .../passes/parallel_group_pass_unittest.cc | 304 +++++++++++++++ 11 files changed, 831 insertions(+), 1 deletion(-) create mode 100644 ge/graph/passes/parallel_group_pass.cc create mode 100644 ge/graph/passes/parallel_group_pass.h create mode 100644 tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index c29936bb..1a17c427 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -320,6 +320,7 @@ set(TRAIN_SRC_LIST "graph/passes/variable_ref_useless_control_out_delete_pass.cc" "graph/passes/end_of_sequence_add_control_pass.cc" "graph/passes/memcpy_addr_async_pass.cc" + "graph/passes/parallel_group_pass.cc" "graph/passes/set_input_output_offset_pass.cc" "graph/preprocess/graph_preprocess.cc" "graph/preprocess/insert_op/ge_aipp_op.cc" @@ -607,6 +608,7 @@ set(INFER_SRC_LIST "graph/passes/hccl_group_pass.cc" "graph/passes/memcpy_addr_async_pass.cc" "graph/passes/set_input_output_offset_pass.cc" + "graph/passes/parallel_group_pass.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/rt_context_util.cc" "graph/manager/util/variable_accelerate_ctrl.cc" diff --git a/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc index 3bc29b70..bfa1bb1a 100644 --- a/ge/graph/build/logical_stream_allocator.cc +++ b/ge/graph/build/logical_stream_allocator.cc @@ -376,6 +376,48 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { + std::map> stream_op_map; + for (const SubgraphPtr &subgraph : subgraphs) { + auto compute_graph = subgraph->subgraph_info.GetSubGraph(); + for (const NodePtr &node : compute_graph->GetDirectNode()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { + int64_t op_desc_stream_id = op_desc->GetStreamId(); + stream_op_map[op_desc_stream_id].push_back(op_desc); + } + } + } + for (const auto &itr : stream_op_map) { + if (itr.first == kInvalidStream) { + continue; + } + std::map group_2_stream_id; + for (const auto &op_desc : itr.second) { + std::string group_name; + if (!AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) { + GELOGE(FAILED, "[GetAttr][OpDesc]Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str()); + return FAILED; + } + const auto &itr = group_2_stream_id.find(group_name); + int64_t new_stream_id = kInvalidStream; + int64_t old_stream_id = op_desc->GetStreamId(); + if (itr != group_2_stream_id.end()) { + new_stream_id = itr->second; + } else { + new_stream_id = context.next_stream++; + group_2_stream_id[group_name] = new_stream_id; + } + op_desc->SetStreamId(new_stream_id); + GELOGD("Node %s assigned stream %ld from stream %ld.", + op_desc->GetName().c_str(), new_stream_id, old_stream_id); + } + } + return SUCCESS; +} + int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const { set stream_ids; @@ -665,6 +707,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); } diff --git a/ge/graph/build/logical_stream_allocator.h b/ge/graph/build/logical_stream_allocator.h index b9aec611..2a94c254 100644 --- a/ge/graph/build/logical_stream_allocator.h +++ b/ge/graph/build/logical_stream_allocator.h @@ -149,6 +149,13 @@ class NodeStreamUpdatePass : public LogicalStreamPass { Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; }; +// assign stream by parallel group +class UpdateForParallelGroupPass : public LogicalStreamPass { + public: + STREAM_PASS_DEFAULT_FUNC(UpdateForParallelGroupPass); + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; +}; + // Update the stream of subgraphs to nodes. class UpdateForSkippedEnginePass : public LogicalStreamPass { public: diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 37209aae..cc3420df 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -93,6 +93,7 @@ #include "graph/passes/global_step_insert_pass.h" #include "graph/passes/memcpy_addr_async_pass.h" #include "graph/passes/hccl_continuous_memcpy_pass.h" +#include "graph/passes/parallel_group_pass.h" #include "graph/build/label_allocator.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" @@ -2381,6 +2382,12 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed."); GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); + // Handle parallel group . + GE_TIMESTAMP_START(ParallelGroup); + ParallelGroupPass parallel_group_pass; + GE_CHK_STATUS_RET(parallel_group_pass.Run(compute_graph), "Handle parallel group failed."); + GE_TIMESTAMP_END(ParallelGroup, "ParallelGroupPass::Run."); + // After while sub graph handle, mark all node rw type auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph); if (result != SUCCESS) { diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index cf46f09d..8d76da32 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -22,6 +22,10 @@ using std::string; namespace ge { +namespace { +const int64_t kLoopType = 1; +} + Status NextIterationPass::Run(ComputeGraphPtr graph) { GELOGD("NextIterationPass Enter"); /// Enter-----------+ @@ -121,7 +125,10 @@ Status NextIterationPass::FindWhileGroups() { if (switch_node == nullptr) { continue; } - + if (!AttrUtils::SetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, kLoopType)) { + GELOGE(INTERNAL_ERROR, "set int failed"); + return INTERNAL_ERROR; + } NodePtr loop_cond = nullptr; if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); diff --git a/ge/graph/passes/parallel_group_pass.cc b/ge/graph/passes/parallel_group_pass.cc new file mode 100644 index 00000000..0d033fbf --- /dev/null +++ b/ge/graph/passes/parallel_group_pass.cc @@ -0,0 +1,354 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/parallel_group_pass.h" + +#include "framework/common/debug/ge_log.h" +#include "common/ge/ge_util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" + +namespace ge { +namespace { +const int32_t kMaxRecursionDepth = 10; +const int64_t kLoopType = 1; +} + +Status ParallelGroupPass::Run(ComputeGraphPtr graph) { + GELOGD("ParallelGroupPass running"); + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "[Check][Graph]Input param graph is null, skip ParallelGroupPass."); + REPORT_INNER_ERROR("E19999", "Input param graph is null, skip ParallelGroupPass."); + return PARAM_INVALID; + } + + if (graph->GetParentGraph() != nullptr) { + GELOGD("Current graph %s is a subgraph, this pass only support root graph.", + graph->GetName().c_str()); + return SUCCESS; + } + + if (graph->TopologicalSorting() != GRAPH_SUCCESS) { + GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.", + graph->GetName().c_str()); + return FAILED; + } + + std::unordered_set parallel_groups; + int depth = 0; + if (ProcessGraphGroupNodes(graph, depth, parallel_groups) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Process][Graph]Process group nodes of graph %s failed.", graph->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (graph->TopologicalSorting() != GRAPH_SUCCESS) { + GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.", + graph->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth, + std::unordered_set ¶llel_groups) { + if (depth >= kMaxRecursionDepth) { + GELOGE(FAILED, "[Process][SubGraph]There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth); + REPORT_INNER_ERROR("E19999", "There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth); + return FAILED; + } + std::map> group_nodes; + auto candidates = graph->GetDirectNode(); + auto root_graph = GraphUtils::FindRootGraph(graph); + for (const auto &node : candidates) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + std::string group_name; + if (AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) { + group_nodes[group_name].push_back(node); + parallel_groups.insert(group_name); + GELOGD("Find group node:%s, group_name:%s", node->GetName().c_str(), group_name.c_str()); + } + + const auto &subgraph_name = op_desc->GetSubgraphInstanceNames(); + GE_CHECK_NOTNULL(root_graph); + for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) { + const auto &sub_graph = root_graph->GetSubgraph(*name_iter); + GE_CHECK_NOTNULL(sub_graph); + // if the pass add control edge for known and unknown graph, then the known graph will become unknown graph + // the order between known and unknown graph is guaranteed by dynamic shape executor + // so the parallel group pass do nothing for unknown graph + if (sub_graph->GetGraphUnknownFlag()) { + continue; + } + std::unordered_set sub_parallel_groups; + auto ret = ProcessGraphGroupNodes(sub_graph, depth + 1, sub_parallel_groups); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Process][SubGraph]Process sub graph %s failed.", sub_graph->GetName().c_str()); + return FAILED; + } + for (const auto &sub_parallel_group : sub_parallel_groups) { + parallel_groups.insert(sub_parallel_group); + group_nodes[sub_parallel_group].emplace_back(node); + } + } + } + + std::map, NodePtr>> node_2_switch_merge; + if (ProcessGroupNodeInSwitch(graph, node_2_switch_merge) != SUCCESS) { + GELOGE(FAILED, "[Process][Node]Process group node in switch failed, graph:%s.", graph->GetName().c_str()); + return FAILED; + } + + for (const auto &itr : group_nodes) { + const auto &nodes = itr.second; + if (nodes.empty()) { + continue; + } + NodePtr pre_node = nodes[0]; + NodePtr cur_node = nullptr; + for (std::size_t i = 1; i < nodes.size(); i++) { + cur_node = nodes[i]; + GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), + cur_node->GetName().c_str()); + if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) { + GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.", + pre_node->GetName().c_str(), cur_node->GetName().c_str()); + return FAILED; + } + pre_node = cur_node; + } + } + + return SUCCESS; +} + +Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { + if (pre_node == cur_node) { + GELOGD("Pre_node and cur_node are same, no need add anchor"); + return SUCCESS; + } + auto in_nodes = cur_node->GetInAllNodes(); + for (const auto &node : in_nodes) { + if (pre_node == node) { + GELOGD("Node:%s and %s already linked", pre_node->GetName().c_str(), + cur_node->GetName().c_str()); + return SUCCESS; + } + } + GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), + cur_node->GetName().c_str()); + return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), + cur_node->GetInControlAnchor()); +} + +Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph, + std::map, NodePtr>> &node_2_switch_merge) { + + std::string type; + auto direct_nodes = graph->GetDirectNode(); + for (const auto &node : direct_nodes) { + type = node->GetType(); + if (type != STREAMSWITCH) { + continue; + } + + if (IsBigSmallLoopStreamSwitch(node->GetOpDesc()) || + IsWhileStreamSwitch(node->GetOpDesc())) { + continue; + } + + std::vector merge_nodes; + std::set group_nodes; + std::set stream_labels; + + FindGroupNodeAndMerge(node, group_nodes, merge_nodes, stream_labels); + + if (merge_nodes.empty() || (!group_nodes.empty() && stream_labels.size() > 1)) { + GELOGE(FAILED, "[Process][Node]Cannot find merge node or exist switch nestification, switch node:%s," + "merge_vec size:%zu, stream_labels size:%zu, graph:%s.", node->GetName().c_str(), + merge_nodes.size(), stream_labels.size(), graph->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Cannot find merge node or exist switch nest, switch node:%s," + "merge_vec size: %zu, stream_labels size: %zu, graph:%s.", node->GetName().c_str(), + merge_nodes.size(), stream_labels.size(), graph->GetName().c_str()); + return FAILED; + } + + std::sort(merge_nodes.begin(), merge_nodes.end(), + [] (NodePtr a, NodePtr b) -> bool { + return (a->GetOpDesc()->GetId() < b->GetOpDesc()->GetId()); + }); + + NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0); + GE_CHECK_NOTNULL(cast_node); + if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, + cast_node, node, + node_2_switch_merge) != SUCCESS) { + GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", + graph->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set &group_nodes, + std::vector &merge_nodes, std::set &stream_labels) { + std::string type; + std::deque candidates; + std::set visited; + + candidates.push_back(stream_switch_node); + while (!candidates.empty()) { + NodePtr tmp_node = candidates.front(); + candidates.pop_front(); + for (const auto &out_node : tmp_node->GetOutAllNodes()) { + type = out_node->GetType(); + if (type == STREAMMERGE) { + merge_nodes.emplace_back(out_node); + continue; + } + const auto &op = out_node->GetOpDesc(); + if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { + group_nodes.emplace(out_node); + } + if (visited.count(out_node) > 0) { + continue; + } + candidates.push_back(out_node); + visited.insert(out_node); + std::string stream_label; + if (ge::AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { + stream_labels.insert(stream_label); + } + } + } +} + +Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set &group_nodes, + const std::vector &merge_nodes, + const NodePtr &cast_node, const NodePtr &switch_node, + std::map, NodePtr>> &node_2_switch_merge) { + for (const auto &group_node : group_nodes) { + auto itr = node_2_switch_merge.find(group_node); + if (itr != node_2_switch_merge.end()) { + auto &tmp = itr->second; + auto &switch_set = tmp.first; + const auto &merge_node = tmp.second; + GELOGD("Find group node: %s in switch %s and merge %s.", + group_node->GetName().c_str(), switch_node->GetName().c_str(), merge_node->GetName().c_str()); + if (merge_node != merge_nodes.back()) { + GELOGE(FAILED, "[Mapping][Node]Has two different merge nodes: %s and %s, graph's structure is invalid", + merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Has two different merge nodes: %s and %s," + "graph's structure is invalid", + merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str()); + return FAILED; + } + switch_set.insert(cast_node); + } else { + node_2_switch_merge.emplace(group_node, + std::make_pair(std::set{cast_node}, merge_nodes.back())); + } + } + return SUCCESS; +} + +Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node, + const std::map, NodePtr>> &node_2_switch_merge) { + auto pre_itr = node_2_switch_merge.find(pre_node); + auto cur_itr = node_2_switch_merge.find(cur_node); + if (pre_itr != node_2_switch_merge.end()) { + if (cur_itr != node_2_switch_merge.end()) { + const auto &pre_set = pre_itr->second.first; + const auto &cur_set = cur_itr->second.first; + if (!HasSameSwitch(pre_set, cur_set)) { + pre_node = pre_itr->second.second; + for (const auto &switch_node : cur_itr->second.first) { + if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) { + GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + pre_node->GetName().c_str(), switch_node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + pre_node->GetName().c_str(), switch_node->GetName().c_str()); + return FAILED; + } + } + } + return SUCCESS; + } else { + pre_node = pre_itr->second.second; + return AddCtrlEdge(pre_node, cur_node); + } + } else { + if (cur_itr != node_2_switch_merge.end()) { + for (const auto &switch_node : cur_itr->second.first) { + int64_t pre_id = pre_node->GetOpDesc()->GetId(); + int64_t switch_id = switch_node->GetOpDesc()->GetId(); + // avoid ring + if (pre_id > switch_id) { + auto merge_node = cur_itr->second.second; + if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) { + GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + pre_node->GetName().c_str(), switch_node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + pre_node->GetName().c_str(), switch_node->GetName().c_str()); + return FAILED; + } + } else { + if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) { + GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + pre_node->GetName().c_str(), switch_node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + pre_node->GetName().c_str(), switch_node->GetName().c_str()); + return FAILED; + } + } + } + } else { + return AddCtrlEdge(pre_node, cur_node); + } + } + return SUCCESS; +} + +bool ParallelGroupPass::HasSameSwitch(const std::set &switch_set1, const std::set &switch_set2) { + for (const auto &node1 : switch_set1) { + auto itr = switch_set2.find(node1); + if (itr != switch_set2.end()) { + return true; + } + } + return false; +} + +bool ParallelGroupPass::IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc) { + return !AttrUtils::HasAttr(switch_op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG); +} + +bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) { + int64_t stream_switch_type = -1; + return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) && + stream_switch_type == kLoopType); +} +} // namespace ge diff --git a/ge/graph/passes/parallel_group_pass.h b/ge/graph/passes/parallel_group_pass.h new file mode 100644 index 00000000..9b895598 --- /dev/null +++ b/ge/graph/passes/parallel_group_pass.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H +#define GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H + +#include +#include +#include "graph/graph.h" +#include "inc/graph_pass.h" + +namespace ge { +class ParallelGroupPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + private: + Status ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth, std::unordered_set ¶llel_group); + + Status AddCtrlEdge(NodePtr pre_node, NodePtr cur_node); + + Status ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node, + const std::map, NodePtr>> &node_2_switch_merge); + + bool HasSameSwitch(const std::set &a, const std::set &b); + + Status ProcessGroupNodeInSwitch(ComputeGraphPtr graph, + std::map, NodePtr>> &node_2_switch_merge); + + void FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set &group_nodes, + std::vector &merge_nodes, std::set &stream_labels); + + Status MappingNodeToSwitchAndMerge(const std::set &group_set, const std::vector &merge_vec, + const NodePtr &cast_node, const NodePtr &switch_node, + std::map, NodePtr>> &node_2_switch_merge); + + bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc); + bool IsWhileStreamSwitch(OpDescPtr switch_op_desc); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index 392968e7..8cc90eb1 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -307,6 +307,13 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & hccl_group_id.c_str()); } + int64_t switch_type; + if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, switch_type)) { + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, switch_type); + GELOGD("Set attr ATTR_NAME_STREAM_SWITCH_TYPE for Stream_Switch %s, value is %ld.", node_name.c_str(), + switch_type); + } + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) || !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) { GELOGE(INTERNAL_ERROR, "set int failed"); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 80636a20..8d63dcce 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -273,6 +273,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" "${GE_CODE_DIR}/ge/model/ge_model.cc" "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" @@ -518,6 +519,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" #"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" @@ -695,6 +697,7 @@ set(PASS_TEST_FILES "graph/passes/multi_batch_clone_pass_unittest.cc" "graph/passes/replace_with_empty_const_pass_unittest.cc" "graph/passes/transpose_transdata_pass_unittest.cc" + "graph/passes/parallel_group_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc index 5b87939f..218bfd0d 100644 --- a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc +++ b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc @@ -32,6 +32,7 @@ #include "graph/compute_graph.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" +#include "graph/debug/ge_attr_define.h" using namespace std; @@ -153,6 +154,22 @@ class UtestLogicalStreamAllocator : public testing::Test { return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num); } + SubGraphInfoPtr CreateParallelGroupSubgraphWithName(const string &name, const string &engine, + const string &stream_label = "", + std::string group_name = "1") { + ComputeGraphPtr compute_graph = make_shared(name); + OpDescPtr op_desc = std::make_shared("relu", "Relu"); + op_desc->AddInputDesc(GeTensorDesc()); + op_desc->AddOutputDesc(GeTensorDesc()); + AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name); + compute_graph->AddNode(op_desc); + + SubGraphInfoPtr subgraph = BuildSubGraph(compute_graph, engine, stream_label); + AddPlaceHolderAndEnd(subgraph, 1, 1); + + return subgraph; + } + void LinkSubGraph(SubGraphInfoPtr subgraph1, const string &end_name, SubGraphInfoPtr subgraph2, const string &placeholder_name) { NodePtr end_node = subgraph1->GetSubGraph()->FindNode(end_name); @@ -878,4 +895,30 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) { EXPECT_EQ(ret, NOT_CHANGED); } +TEST_F(UtestLogicalStreamAllocator, test_parallel_group) { + SubGraphInfoPtr data = CreateDataSubgraph(); + SubGraphInfoPtr subgraph1 = CreateParallelGroupSubgraphWithName("graph1", "engine1", ""); + SubGraphInfoPtr subgraph2 = CreateParallelGroupSubgraphWithName("graph2", "engine2", "", "2"); + SubGraphInfoPtr subgraph3 = CreateParallelGroupSubgraphWithName("graph3", "engine3", "", "3"); + SubGraphInfoPtr subgraph4 = CreateParallelGroupSubgraphWithName("graph4", "engine4", "", "4"); + LinkSubGraph(data, "end", subgraph1, "placeholder"); + LinkSubGraph(subgraph1, "end", subgraph2, "placeholder"); + LinkSubGraph(subgraph2, "end", subgraph3, "placeholder"); + LinkSubGraph(subgraph3, "end", subgraph4, "placeholder"); + + EngineConfPtr conf1 = make_shared(); + conf1->id = subgraph1->GetEngineName(); + EngineConfPtr conf2 = make_shared(); + conf2->id = subgraph2->GetEngineName(); + conf2->attach = false; + EngineConfPtr conf3 = make_shared(); + conf3->id = subgraph3->GetEngineName(); + conf3->attach = false; + EngineConfPtr conf4 = make_shared(); + conf4->id = subgraph4->GetEngineName(); + + Status status = AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4}, {conf1, conf2, conf3, conf4}); + EXPECT_EQ(status, ge::SUCCESS); +} + } // namespace ge diff --git a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc new file mode 100644 index 00000000..d5b1db41 --- /dev/null +++ b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc @@ -0,0 +1,304 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#define private public + +#include "common/ge_inner_error_codes.h" +#include "inc/pass_manager.h" +#include "utils/graph_utils.h" +#include "graph/passes/parallel_group_pass.h" +#undef private + +namespace ge { +namespace { + +class UtestGraphPassesParallelGgroupPass : public testing::Test { + protected: + UtestGraphPassesParallelGgroupPass() { + graph_ = std::make_shared("test"); + sub_graph_ = std::make_shared("test_subgraph"); + vector shape_vec{1, 1, 1, 1}; + GeShape shape = GeShape(shape_vec); + default_tensor_desc_ = std::make_shared(); + default_tensor_desc_->SetShape(shape); + default_tensor_desc_->SetFormat(FORMAT_NCHW); + default_tensor_desc_->SetDataType(DT_FLOAT); + } + + NodePtr NewNode(const std::string &name, const std::string &type, + int input_cnt, int output_cnt, bool isSubgraph = false) { + OpDescPtr op_desc = std::make_shared(name, type); + for (int i = 0; i < input_cnt; ++i) { + op_desc->AddInputDesc(default_tensor_desc_->Clone()); + } + + for (int i = 0; i < output_cnt; ++i) { + op_desc->AddOutputDesc(default_tensor_desc_->Clone()); + } + NodePtr node = nullptr; + if (isSubgraph) { + node = sub_graph_->AddNode(op_desc); + (void)node->SetOwnerComputeGraph(sub_graph_); + } else { + node = graph_->AddNode(op_desc); + (void)node->SetOwnerComputeGraph(graph_); + } + + return node; + } + + void BuildDefaultGraph() { + /// input + /// \ + /// sqrt pred + /// \ / + /// cast + /// / \ + /// switch_t switch_f + /// | | + /// F T + /// | | + /// Merge + /// | + /// relu + /// | + /// sqrt1 + input_node_ = NewNode("input", RELU, 0, 1); + sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); + pred_node_ = NewNode("pred", GREATER, 2, 1); + cast_node_ = NewNode("cast", CAST, 2, 2); + AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + + switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); + switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); + output_false_node_ = NewNode("false_output", RELU, 1, 1); + AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + output_true_node_ = NewNode("true_output", RELU, 1, 1); + AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); + relu_node_ = NewNode("relu", RELU, 1, 1); + sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); + AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); + + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); + + output_false_node_->GetOpDesc()->SetIsInputConst({false}); + output_true_node_->GetOpDesc()->SetIsInputConst({false}); + } + + void BuildDefaultGraph1() { + /// input + /// \ + /// sqrt pred + /// \ / + /// Switch + /// | | + /// ----F T---- + /// \ | / \ + /// \ Merge1 Merge2 + /// \_________| + input_node_ = NewNode("input", RELU, 0, 1); + AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + pred_node_ = NewNode("pred", GREATER, 2, 1); + sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); + cast_node_ = NewNode("cast", CAST, 2, 2); + + switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); + switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); + output_false_node_ = NewNode("false_output", RELU, 1, 2); + AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + output_true_node_ = NewNode("true_output", RELU, 1, 2); + AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); + merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); + + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); + + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); + + output_false_node_->GetOpDesc()->SetIsInputConst({false}); + output_true_node_->GetOpDesc()->SetIsInputConst({false}); + } + + + void BuildDefaultGraph2() { + /// input input1 + /// \ \ + /// sqrt pred sqrt1 pred1 + /// \ / \ / + /// Switch Switch1 + /// | | _______| + /// | | / + /// ____F T____ + /// \ | / \ + /// \ Merge1 Merge2 + /// \__________| + input_node_ = NewNode("input", RELU, 0, 2); + input_node1_ = NewNode("input_1", RELU, 0, 2); + sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); + pred_node_ = NewNode("pred", GREATER, 2, 1); + sqrt_node1_ = NewNode("sqrt_1", SQRT, 1, 1); + pred_node1_ = NewNode("pred_1", LESS, 2, 1); + cast_node_ = NewNode("cast", CAST, 2, 2); + cast_node1_ = NewNode("cast_1", CAST, 2, 2); + AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + AttrUtils::SetStr(input_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2"); + + switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); + switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); + switch_node1_t = NewNode("switch1_t", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node1_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); + switch_node1_f = NewNode("switch1_f", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node1_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); + output_false_node_ = NewNode("false_output", RELU, 2, 2); + AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + output_true_node_ = NewNode("true_output", RELU, 2, 2); + AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2"); + merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); + merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); + + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); + + GraphUtils::AddEdge(input_node1_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(1)); + GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(0), switch_node1_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(1), switch_node1_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node1_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(switch_node1_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(1)); + + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); + + output_false_node_->GetOpDesc()->SetIsInputConst({false}); + output_true_node_->GetOpDesc()->SetIsInputConst({false}); + } + + ComputeGraphPtr graph_; + ComputeGraphPtr sub_graph_; + GeTensorDescPtr default_tensor_desc_; + ParallelGroupPass pass_; + NodePtr pred_node_; + NodePtr pred_node1_; + NodePtr cast_node_; + NodePtr cast_node1_; + NodePtr sqrt_node_; + NodePtr sqrt_node1_; + NodePtr input_node_; + NodePtr input_node1_; + NodePtr switch_node_t; + NodePtr switch_node_f; + NodePtr switch_node1_t; + NodePtr switch_node1_f; + NodePtr output_false_node_; + NodePtr output_true_node_; + NodePtr merge_node_; + NodePtr merge_node1_; + NodePtr relu_node_; +}; + +TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) { + ComputeGraphPtr graph = nullptr; + auto ret = pass_.Run(graph); + EXPECT_EQ(ret, PARAM_INVALID); +} + +TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) { + BuildDefaultGraph(); + auto ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); + EXPECT_EQ(true, merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor())); + EXPECT_EQ(false, output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor())); +} + +TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) { + BuildDefaultGraph1(); + auto ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); +} + +TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { + BuildDefaultGraph2(); + auto ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); + EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); +} + +TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { + BuildDefaultGraph1(); + NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); + NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true); + NodePtr add = NewNode("add", ADD, 2, 1, true); + AttrUtils::SetStr(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + AttrUtils::SetStr(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + + sub_graph_->SetParentNode(input_node_); + sub_graph_->SetParentGraph(graph_); + auto ret = graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + ret = input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName()); + EXPECT_EQ(ret, GRAPH_SUCCESS); + ret = input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName()); + EXPECT_EQ(ret, GRAPH_SUCCESS); + ret = pass_.Run(sub_graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); +} + +} // namespace +} // namespace ge