@@ -320,6 +320,7 @@ set(TRAIN_SRC_LIST | |||||
"graph/passes/variable_ref_useless_control_out_delete_pass.cc" | "graph/passes/variable_ref_useless_control_out_delete_pass.cc" | ||||
"graph/passes/end_of_sequence_add_control_pass.cc" | "graph/passes/end_of_sequence_add_control_pass.cc" | ||||
"graph/passes/memcpy_addr_async_pass.cc" | "graph/passes/memcpy_addr_async_pass.cc" | ||||
"graph/passes/parallel_group_pass.cc" | |||||
"graph/passes/set_input_output_offset_pass.cc" | "graph/passes/set_input_output_offset_pass.cc" | ||||
"graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
@@ -607,6 +608,7 @@ set(INFER_SRC_LIST | |||||
"graph/passes/hccl_group_pass.cc" | "graph/passes/hccl_group_pass.cc" | ||||
"graph/passes/memcpy_addr_async_pass.cc" | "graph/passes/memcpy_addr_async_pass.cc" | ||||
"graph/passes/set_input_output_offset_pass.cc" | "graph/passes/set_input_output_offset_pass.cc" | ||||
"graph/passes/parallel_group_pass.cc" | |||||
"graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
"graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
"graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
@@ -376,6 +376,48 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status UpdateForParallelGroupPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
std::map<int, vector<OpDescPtr>> stream_op_map; | |||||
for (const SubgraphPtr &subgraph : subgraphs) { | |||||
auto compute_graph = subgraph->subgraph_info.GetSubGraph(); | |||||
for (const NodePtr &node : compute_graph->GetDirectNode()) { | |||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (op_desc->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { | |||||
int64_t op_desc_stream_id = op_desc->GetStreamId(); | |||||
stream_op_map[op_desc_stream_id].push_back(op_desc); | |||||
} | |||||
} | |||||
} | |||||
for (const auto &itr : stream_op_map) { | |||||
if (itr.first == kInvalidStream) { | |||||
continue; | |||||
} | |||||
std::map<std::string, int64_t> group_2_stream_id; | |||||
for (const auto &op_desc : itr.second) { | |||||
std::string group_name; | |||||
if (!AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) { | |||||
GELOGE(FAILED, "[GetAttr][OpDesc]Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
const auto &itr = group_2_stream_id.find(group_name); | |||||
int64_t new_stream_id = kInvalidStream; | |||||
int64_t old_stream_id = op_desc->GetStreamId(); | |||||
if (itr != group_2_stream_id.end()) { | |||||
new_stream_id = itr->second; | |||||
} else { | |||||
new_stream_id = context.next_stream++; | |||||
group_2_stream_id[group_name] = new_stream_id; | |||||
} | |||||
op_desc->SetStreamId(new_stream_id); | |||||
GELOGD("Node %s assigned stream %ld from stream %ld.", | |||||
op_desc->GetName().c_str(), new_stream_id, old_stream_id); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const { | int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const { | ||||
set<int64_t> stream_ids; | set<int64_t> stream_ids; | ||||
@@ -665,6 +707,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec | |||||
passes.emplace_back(MakeShared<IndependentStreamPass>()); | passes.emplace_back(MakeShared<IndependentStreamPass>()); | ||||
passes.emplace_back(MakeShared<AssignByDependencyPass>()); | passes.emplace_back(MakeShared<AssignByDependencyPass>()); | ||||
passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); | passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); | ||||
passes.emplace_back(MakeShared<UpdateForParallelGroupPass>()); | |||||
passes.emplace_back(MakeShared<AllReduceParallelPass>()); | passes.emplace_back(MakeShared<AllReduceParallelPass>()); | ||||
passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>()); | passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>()); | ||||
} | } | ||||
@@ -149,6 +149,13 @@ class NodeStreamUpdatePass : public LogicalStreamPass { | |||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | ||||
}; | }; | ||||
// assign stream by parallel group | |||||
class UpdateForParallelGroupPass : public LogicalStreamPass { | |||||
public: | |||||
STREAM_PASS_DEFAULT_FUNC(UpdateForParallelGroupPass); | |||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
}; | |||||
// Update the stream of subgraphs to nodes. | // Update the stream of subgraphs to nodes. | ||||
class UpdateForSkippedEnginePass : public LogicalStreamPass { | class UpdateForSkippedEnginePass : public LogicalStreamPass { | ||||
public: | public: | ||||
@@ -93,6 +93,7 @@ | |||||
#include "graph/passes/global_step_insert_pass.h" | #include "graph/passes/global_step_insert_pass.h" | ||||
#include "graph/passes/memcpy_addr_async_pass.h" | #include "graph/passes/memcpy_addr_async_pass.h" | ||||
#include "graph/passes/hccl_continuous_memcpy_pass.h" | #include "graph/passes/hccl_continuous_memcpy_pass.h" | ||||
#include "graph/passes/parallel_group_pass.h" | |||||
#include "graph/build/label_allocator.h" | #include "graph/build/label_allocator.h" | ||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
@@ -2381,6 +2382,12 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { | |||||
GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed."); | GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed."); | ||||
GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); | GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); | ||||
// Handle parallel group . | |||||
GE_TIMESTAMP_START(ParallelGroup); | |||||
ParallelGroupPass parallel_group_pass; | |||||
GE_CHK_STATUS_RET(parallel_group_pass.Run(compute_graph), "Handle parallel group failed."); | |||||
GE_TIMESTAMP_END(ParallelGroup, "ParallelGroupPass::Run."); | |||||
// After while sub graph handle, mark all node rw type | // After while sub graph handle, mark all node rw type | ||||
auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph); | auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph); | ||||
if (result != SUCCESS) { | if (result != SUCCESS) { | ||||
@@ -22,6 +22,10 @@ | |||||
using std::string; | using std::string; | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
const int64_t kLoopType = 1; | |||||
} | |||||
Status NextIterationPass::Run(ComputeGraphPtr graph) { | Status NextIterationPass::Run(ComputeGraphPtr graph) { | ||||
GELOGD("NextIterationPass Enter"); | GELOGD("NextIterationPass Enter"); | ||||
/// Enter-----------+ | /// Enter-----------+ | ||||
@@ -121,7 +125,10 @@ Status NextIterationPass::FindWhileGroups() { | |||||
if (switch_node == nullptr) { | if (switch_node == nullptr) { | ||||
continue; | continue; | ||||
} | } | ||||
if (!AttrUtils::SetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, kLoopType)) { | |||||
GELOGE(INTERNAL_ERROR, "set int failed"); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
NodePtr loop_cond = nullptr; | NodePtr loop_cond = nullptr; | ||||
if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { | if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); | GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); | ||||
@@ -0,0 +1,354 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/parallel_group_pass.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "common/ge/ge_util.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
namespace ge { | |||||
namespace { | |||||
const int32_t kMaxRecursionDepth = 10; | |||||
const int64_t kLoopType = 1; | |||||
} | |||||
Status ParallelGroupPass::Run(ComputeGraphPtr graph) { | |||||
GELOGD("ParallelGroupPass running"); | |||||
if (graph == nullptr) { | |||||
GELOGE(PARAM_INVALID, "[Check][Graph]Input param graph is null, skip ParallelGroupPass."); | |||||
REPORT_INNER_ERROR("E19999", "Input param graph is null, skip ParallelGroupPass."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (graph->GetParentGraph() != nullptr) { | |||||
GELOGD("Current graph %s is a subgraph, this pass only support root graph.", | |||||
graph->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
if (graph->TopologicalSorting() != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.", | |||||
graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
std::unordered_set<std::string> parallel_groups; | |||||
int depth = 0; | |||||
if (ProcessGraphGroupNodes(graph, depth, parallel_groups) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "[Process][Graph]Process group nodes of graph %s failed.", graph->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
if (graph->TopologicalSorting() != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.", | |||||
graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth, | |||||
std::unordered_set<std::string> ¶llel_groups) { | |||||
if (depth >= kMaxRecursionDepth) { | |||||
GELOGE(FAILED, "[Process][SubGraph]There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth); | |||||
REPORT_INNER_ERROR("E19999", "There are too much subgraphs:%d > %d(max subgraphs)", depth, kMaxRecursionDepth); | |||||
return FAILED; | |||||
} | |||||
std::map<std::string, vector<NodePtr>> group_nodes; | |||||
auto candidates = graph->GetDirectNode(); | |||||
auto root_graph = GraphUtils::FindRootGraph(graph); | |||||
for (const auto &node : candidates) { | |||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
continue; | |||||
} | |||||
std::string group_name; | |||||
if (AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) { | |||||
group_nodes[group_name].push_back(node); | |||||
parallel_groups.insert(group_name); | |||||
GELOGD("Find group node:%s, group_name:%s", node->GetName().c_str(), group_name.c_str()); | |||||
} | |||||
const auto &subgraph_name = op_desc->GetSubgraphInstanceNames(); | |||||
GE_CHECK_NOTNULL(root_graph); | |||||
for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) { | |||||
const auto &sub_graph = root_graph->GetSubgraph(*name_iter); | |||||
GE_CHECK_NOTNULL(sub_graph); | |||||
// if the pass add control edge for known and unknown graph, then the known graph will become unknown graph | |||||
// the order between known and unknown graph is guaranteed by dynamic shape executor | |||||
// so the parallel group pass do nothing for unknown graph | |||||
if (sub_graph->GetGraphUnknownFlag()) { | |||||
continue; | |||||
} | |||||
std::unordered_set<std::string> sub_parallel_groups; | |||||
auto ret = ProcessGraphGroupNodes(sub_graph, depth + 1, sub_parallel_groups); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "[Process][SubGraph]Process sub graph %s failed.", sub_graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
for (const auto &sub_parallel_group : sub_parallel_groups) { | |||||
parallel_groups.insert(sub_parallel_group); | |||||
group_nodes[sub_parallel_group].emplace_back(node); | |||||
} | |||||
} | |||||
} | |||||
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge; | |||||
if (ProcessGroupNodeInSwitch(graph, node_2_switch_merge) != SUCCESS) { | |||||
GELOGE(FAILED, "[Process][Node]Process group node in switch failed, graph:%s.", graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
for (const auto &itr : group_nodes) { | |||||
const auto &nodes = itr.second; | |||||
if (nodes.empty()) { | |||||
continue; | |||||
} | |||||
NodePtr pre_node = nodes[0]; | |||||
NodePtr cur_node = nullptr; | |||||
for (std::size_t i = 1; i < nodes.size(); i++) { | |||||
cur_node = nodes[i]; | |||||
GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), | |||||
cur_node->GetName().c_str()); | |||||
if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) { | |||||
GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.", | |||||
pre_node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
pre_node = cur_node; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { | |||||
if (pre_node == cur_node) { | |||||
GELOGD("Pre_node and cur_node are same, no need add anchor"); | |||||
return SUCCESS; | |||||
} | |||||
auto in_nodes = cur_node->GetInAllNodes(); | |||||
for (const auto &node : in_nodes) { | |||||
if (pre_node == node) { | |||||
GELOGD("Node:%s and %s already linked", pre_node->GetName().c_str(), | |||||
cur_node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
} | |||||
GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), | |||||
cur_node->GetName().c_str()); | |||||
return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), | |||||
cur_node->GetInControlAnchor()); | |||||
} | |||||
Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph, | |||||
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) { | |||||
std::string type; | |||||
auto direct_nodes = graph->GetDirectNode(); | |||||
for (const auto &node : direct_nodes) { | |||||
type = node->GetType(); | |||||
if (type != STREAMSWITCH) { | |||||
continue; | |||||
} | |||||
if (IsBigSmallLoopStreamSwitch(node->GetOpDesc()) || | |||||
IsWhileStreamSwitch(node->GetOpDesc())) { | |||||
continue; | |||||
} | |||||
std::vector<NodePtr> merge_nodes; | |||||
std::set<NodePtr> group_nodes; | |||||
std::set<std::string> stream_labels; | |||||
FindGroupNodeAndMerge(node, group_nodes, merge_nodes, stream_labels); | |||||
if (merge_nodes.empty() || (!group_nodes.empty() && stream_labels.size() > 1)) { | |||||
GELOGE(FAILED, "[Process][Node]Cannot find merge node or exist switch nestification, switch node:%s," | |||||
"merge_vec size:%zu, stream_labels size:%zu, graph:%s.", node->GetName().c_str(), | |||||
merge_nodes.size(), stream_labels.size(), graph->GetName().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Cannot find merge node or exist switch nest, switch node:%s," | |||||
"merge_vec size: %zu, stream_labels size: %zu, graph:%s.", node->GetName().c_str(), | |||||
merge_nodes.size(), stream_labels.size(), graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
std::sort(merge_nodes.begin(), merge_nodes.end(), | |||||
[] (NodePtr a, NodePtr b) -> bool { | |||||
return (a->GetOpDesc()->GetId() < b->GetOpDesc()->GetId()); | |||||
}); | |||||
NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0); | |||||
GE_CHECK_NOTNULL(cast_node); | |||||
if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, | |||||
cast_node, node, | |||||
node_2_switch_merge) != SUCCESS) { | |||||
GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", | |||||
graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes, | |||||
std::vector<NodePtr> &merge_nodes, std::set<std::string> &stream_labels) { | |||||
std::string type; | |||||
std::deque<NodePtr> candidates; | |||||
std::set<NodePtr> visited; | |||||
candidates.push_back(stream_switch_node); | |||||
while (!candidates.empty()) { | |||||
NodePtr tmp_node = candidates.front(); | |||||
candidates.pop_front(); | |||||
for (const auto &out_node : tmp_node->GetOutAllNodes()) { | |||||
type = out_node->GetType(); | |||||
if (type == STREAMMERGE) { | |||||
merge_nodes.emplace_back(out_node); | |||||
continue; | |||||
} | |||||
const auto &op = out_node->GetOpDesc(); | |||||
if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { | |||||
group_nodes.emplace(out_node); | |||||
} | |||||
if (visited.count(out_node) > 0) { | |||||
continue; | |||||
} | |||||
candidates.push_back(out_node); | |||||
visited.insert(out_node); | |||||
std::string stream_label; | |||||
if (ge::AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { | |||||
stream_labels.insert(stream_label); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes, | |||||
const std::vector<NodePtr> &merge_nodes, | |||||
const NodePtr &cast_node, const NodePtr &switch_node, | |||||
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) { | |||||
for (const auto &group_node : group_nodes) { | |||||
auto itr = node_2_switch_merge.find(group_node); | |||||
if (itr != node_2_switch_merge.end()) { | |||||
auto &tmp = itr->second; | |||||
auto &switch_set = tmp.first; | |||||
const auto &merge_node = tmp.second; | |||||
GELOGD("Find group node: %s in switch %s and merge %s.", | |||||
group_node->GetName().c_str(), switch_node->GetName().c_str(), merge_node->GetName().c_str()); | |||||
if (merge_node != merge_nodes.back()) { | |||||
GELOGE(FAILED, "[Mapping][Node]Has two different merge nodes: %s and %s, graph's structure is invalid", | |||||
merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Has two different merge nodes: %s and %s," | |||||
"graph's structure is invalid", | |||||
merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
switch_set.insert(cast_node); | |||||
} else { | |||||
node_2_switch_merge.emplace(group_node, | |||||
std::make_pair(std::set<NodePtr>{cast_node}, merge_nodes.back())); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node, | |||||
const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) { | |||||
auto pre_itr = node_2_switch_merge.find(pre_node); | |||||
auto cur_itr = node_2_switch_merge.find(cur_node); | |||||
if (pre_itr != node_2_switch_merge.end()) { | |||||
if (cur_itr != node_2_switch_merge.end()) { | |||||
const auto &pre_set = pre_itr->second.first; | |||||
const auto &cur_set = cur_itr->second.first; | |||||
if (!HasSameSwitch(pre_set, cur_set)) { | |||||
pre_node = pre_itr->second.second; | |||||
for (const auto &switch_node : cur_itr->second.first) { | |||||
if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) { | |||||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} else { | |||||
pre_node = pre_itr->second.second; | |||||
return AddCtrlEdge(pre_node, cur_node); | |||||
} | |||||
} else { | |||||
if (cur_itr != node_2_switch_merge.end()) { | |||||
for (const auto &switch_node : cur_itr->second.first) { | |||||
int64_t pre_id = pre_node->GetOpDesc()->GetId(); | |||||
int64_t switch_id = switch_node->GetOpDesc()->GetId(); | |||||
// avoid ring | |||||
if (pre_id > switch_id) { | |||||
auto merge_node = cur_itr->second.second; | |||||
if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) { | |||||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} else { | |||||
if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) { | |||||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
return AddCtrlEdge(pre_node, cur_node); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) { | |||||
for (const auto &node1 : switch_set1) { | |||||
auto itr = switch_set2.find(node1); | |||||
if (itr != switch_set2.end()) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
bool ParallelGroupPass::IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc) { | |||||
return !AttrUtils::HasAttr(switch_op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG); | |||||
} | |||||
bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) { | |||||
int64_t stream_switch_type = -1; | |||||
return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) && | |||||
stream_switch_type == kLoopType); | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,53 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H | |||||
#define GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H | |||||
#include <map> | |||||
#include <unordered_set> | |||||
#include "graph/graph.h" | |||||
#include "inc/graph_pass.h" | |||||
namespace ge { | |||||
class ParallelGroupPass : public GraphPass { | |||||
public: | |||||
Status Run(ComputeGraphPtr graph) override; | |||||
private: | |||||
Status ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth, std::unordered_set<std::string> ¶llel_group); | |||||
Status AddCtrlEdge(NodePtr pre_node, NodePtr cur_node); | |||||
Status ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node, | |||||
const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge); | |||||
bool HasSameSwitch(const std::set<NodePtr> &a, const std::set<NodePtr> &b); | |||||
Status ProcessGroupNodeInSwitch(ComputeGraphPtr graph, | |||||
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge); | |||||
void FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes, | |||||
std::vector<NodePtr> &merge_nodes, std::set<std::string> &stream_labels); | |||||
Status MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_set, const std::vector<NodePtr> &merge_vec, | |||||
const NodePtr &cast_node, const NodePtr &switch_node, | |||||
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge); | |||||
bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc); | |||||
bool IsWhileStreamSwitch(OpDescPtr switch_op_desc); | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H |
@@ -307,6 +307,13 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||||
hccl_group_id.c_str()); | hccl_group_id.c_str()); | ||||
} | } | ||||
int64_t switch_type; | |||||
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, switch_type)) { | |||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, switch_type); | |||||
GELOGD("Set attr ATTR_NAME_STREAM_SWITCH_TYPE for Stream_Switch %s, value is %ld.", node_name.c_str(), | |||||
switch_type); | |||||
} | |||||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) || | if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) || | ||||
!AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) { | !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) { | ||||
GELOGE(INTERNAL_ERROR, "set int failed"); | GELOGE(INTERNAL_ERROR, "set int failed"); | ||||
@@ -273,6 +273,7 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" | |||||
"${GE_CODE_DIR}/ge/model/ge_model.cc" | "${GE_CODE_DIR}/ge/model/ge_model.cc" | ||||
"${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" | "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" | ||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" | ||||
@@ -518,6 +519,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/common/transop_util.cc" | "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" | |||||
#"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" | #"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" | ||||
@@ -695,6 +697,7 @@ set(PASS_TEST_FILES | |||||
"graph/passes/multi_batch_clone_pass_unittest.cc" | "graph/passes/multi_batch_clone_pass_unittest.cc" | ||||
"graph/passes/replace_with_empty_const_pass_unittest.cc" | "graph/passes/replace_with_empty_const_pass_unittest.cc" | ||||
"graph/passes/transpose_transdata_pass_unittest.cc" | "graph/passes/transpose_transdata_pass_unittest.cc" | ||||
"graph/passes/parallel_group_pass_unittest.cc" | |||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -32,6 +32,7 @@ | |||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
using namespace std; | using namespace std; | ||||
@@ -153,6 +154,22 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num); | return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num); | ||||
} | } | ||||
SubGraphInfoPtr CreateParallelGroupSubgraphWithName(const string &name, const string &engine, | |||||
const string &stream_label = "", | |||||
std::string group_name = "1") { | |||||
ComputeGraphPtr compute_graph = make_shared<ComputeGraph>(name); | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>("relu", "Relu"); | |||||
op_desc->AddInputDesc(GeTensorDesc()); | |||||
op_desc->AddOutputDesc(GeTensorDesc()); | |||||
AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name); | |||||
compute_graph->AddNode(op_desc); | |||||
SubGraphInfoPtr subgraph = BuildSubGraph(compute_graph, engine, stream_label); | |||||
AddPlaceHolderAndEnd(subgraph, 1, 1); | |||||
return subgraph; | |||||
} | |||||
void LinkSubGraph(SubGraphInfoPtr subgraph1, const string &end_name, SubGraphInfoPtr subgraph2, | void LinkSubGraph(SubGraphInfoPtr subgraph1, const string &end_name, SubGraphInfoPtr subgraph2, | ||||
const string &placeholder_name) { | const string &placeholder_name) { | ||||
NodePtr end_node = subgraph1->GetSubGraph()->FindNode(end_name); | NodePtr end_node = subgraph1->GetSubGraph()->FindNode(end_name); | ||||
@@ -878,4 +895,30 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) { | |||||
EXPECT_EQ(ret, NOT_CHANGED); | EXPECT_EQ(ret, NOT_CHANGED); | ||||
} | } | ||||
TEST_F(UtestLogicalStreamAllocator, test_parallel_group) { | |||||
SubGraphInfoPtr data = CreateDataSubgraph(); | |||||
SubGraphInfoPtr subgraph1 = CreateParallelGroupSubgraphWithName("graph1", "engine1", ""); | |||||
SubGraphInfoPtr subgraph2 = CreateParallelGroupSubgraphWithName("graph2", "engine2", "", "2"); | |||||
SubGraphInfoPtr subgraph3 = CreateParallelGroupSubgraphWithName("graph3", "engine3", "", "3"); | |||||
SubGraphInfoPtr subgraph4 = CreateParallelGroupSubgraphWithName("graph4", "engine4", "", "4"); | |||||
LinkSubGraph(data, "end", subgraph1, "placeholder"); | |||||
LinkSubGraph(subgraph1, "end", subgraph2, "placeholder"); | |||||
LinkSubGraph(subgraph2, "end", subgraph3, "placeholder"); | |||||
LinkSubGraph(subgraph3, "end", subgraph4, "placeholder"); | |||||
EngineConfPtr conf1 = make_shared<EngineConf>(); | |||||
conf1->id = subgraph1->GetEngineName(); | |||||
EngineConfPtr conf2 = make_shared<EngineConf>(); | |||||
conf2->id = subgraph2->GetEngineName(); | |||||
conf2->attach = false; | |||||
EngineConfPtr conf3 = make_shared<EngineConf>(); | |||||
conf3->id = subgraph3->GetEngineName(); | |||||
conf3->attach = false; | |||||
EngineConfPtr conf4 = make_shared<EngineConf>(); | |||||
conf4->id = subgraph4->GetEngineName(); | |||||
Status status = AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4}, {conf1, conf2, conf3, conf4}); | |||||
EXPECT_EQ(status, ge::SUCCESS); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -0,0 +1,304 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <cstdint> | |||||
#include <string> | |||||
#define private public | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "inc/pass_manager.h" | |||||
#include "utils/graph_utils.h" | |||||
#include "graph/passes/parallel_group_pass.h" | |||||
#undef private | |||||
namespace ge { | |||||
namespace { | |||||
class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
protected: | |||||
UtestGraphPassesParallelGgroupPass() { | |||||
graph_ = std::make_shared<ComputeGraph>("test"); | |||||
sub_graph_ = std::make_shared<ComputeGraph>("test_subgraph"); | |||||
vector<int64_t> shape_vec{1, 1, 1, 1}; | |||||
GeShape shape = GeShape(shape_vec); | |||||
default_tensor_desc_ = std::make_shared<GeTensorDesc>(); | |||||
default_tensor_desc_->SetShape(shape); | |||||
default_tensor_desc_->SetFormat(FORMAT_NCHW); | |||||
default_tensor_desc_->SetDataType(DT_FLOAT); | |||||
} | |||||
NodePtr NewNode(const std::string &name, const std::string &type, | |||||
int input_cnt, int output_cnt, bool isSubgraph = false) { | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
for (int i = 0; i < input_cnt; ++i) { | |||||
op_desc->AddInputDesc(default_tensor_desc_->Clone()); | |||||
} | |||||
for (int i = 0; i < output_cnt; ++i) { | |||||
op_desc->AddOutputDesc(default_tensor_desc_->Clone()); | |||||
} | |||||
NodePtr node = nullptr; | |||||
if (isSubgraph) { | |||||
node = sub_graph_->AddNode(op_desc); | |||||
(void)node->SetOwnerComputeGraph(sub_graph_); | |||||
} else { | |||||
node = graph_->AddNode(op_desc); | |||||
(void)node->SetOwnerComputeGraph(graph_); | |||||
} | |||||
return node; | |||||
} | |||||
void BuildDefaultGraph() { | |||||
/// input | |||||
/// \ | |||||
/// sqrt pred | |||||
/// \ / | |||||
/// cast | |||||
/// / \ | |||||
/// switch_t switch_f | |||||
/// | | | |||||
/// F T | |||||
/// | | | |||||
/// Merge | |||||
/// | | |||||
/// relu | |||||
/// | | |||||
/// sqrt1 | |||||
input_node_ = NewNode("input", RELU, 0, 1); | |||||
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); | |||||
pred_node_ = NewNode("pred", GREATER, 2, 1); | |||||
cast_node_ = NewNode("cast", CAST, 2, 2); | |||||
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); | |||||
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); | |||||
output_false_node_ = NewNode("false_output", RELU, 1, 1); | |||||
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
output_true_node_ = NewNode("true_output", RELU, 1, 1); | |||||
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); | |||||
relu_node_ = NewNode("relu", RELU, 1, 1); | |||||
sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); | |||||
AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); | |||||
output_false_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
output_true_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
} | |||||
void BuildDefaultGraph1() { | |||||
/// input | |||||
/// \ | |||||
/// sqrt pred | |||||
/// \ / | |||||
/// Switch | |||||
/// | | | |||||
/// ----F T---- | |||||
/// \ | / \ | |||||
/// \ Merge1 Merge2 | |||||
/// \_________| | |||||
input_node_ = NewNode("input", RELU, 0, 1); | |||||
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
pred_node_ = NewNode("pred", GREATER, 2, 1); | |||||
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); | |||||
cast_node_ = NewNode("cast", CAST, 2, 2); | |||||
switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); | |||||
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); | |||||
output_false_node_ = NewNode("false_output", RELU, 1, 2); | |||||
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
output_true_node_ = NewNode("true_output", RELU, 1, 2); | |||||
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); | |||||
merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); | |||||
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); | |||||
output_false_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
output_true_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
} | |||||
void BuildDefaultGraph2() { | |||||
/// input input1 | |||||
/// \ \ | |||||
/// sqrt pred sqrt1 pred1 | |||||
/// \ / \ / | |||||
/// Switch Switch1 | |||||
/// | | _______| | |||||
/// | | / | |||||
/// ____F T____ | |||||
/// \ | / \ | |||||
/// \ Merge1 Merge2 | |||||
/// \__________| | |||||
input_node_ = NewNode("input", RELU, 0, 2); | |||||
input_node1_ = NewNode("input_1", RELU, 0, 2); | |||||
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); | |||||
pred_node_ = NewNode("pred", GREATER, 2, 1); | |||||
sqrt_node1_ = NewNode("sqrt_1", SQRT, 1, 1); | |||||
pred_node1_ = NewNode("pred_1", LESS, 2, 1); | |||||
cast_node_ = NewNode("cast", CAST, 2, 2); | |||||
cast_node1_ = NewNode("cast_1", CAST, 2, 2); | |||||
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
AttrUtils::SetStr(input_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2"); | |||||
switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); | |||||
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); | |||||
switch_node1_t = NewNode("switch1_t", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node1_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); | |||||
switch_node1_f = NewNode("switch1_f", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node1_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); | |||||
output_false_node_ = NewNode("false_output", RELU, 2, 2); | |||||
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
output_true_node_ = NewNode("true_output", RELU, 2, 2); | |||||
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2"); | |||||
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); | |||||
merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); | |||||
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(input_node1_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(0), switch_node1_t->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(1), switch_node1_f->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node1_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(switch_node1_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); | |||||
output_false_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
output_true_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
} | |||||
ComputeGraphPtr graph_; | |||||
ComputeGraphPtr sub_graph_; | |||||
GeTensorDescPtr default_tensor_desc_; | |||||
ParallelGroupPass pass_; | |||||
NodePtr pred_node_; | |||||
NodePtr pred_node1_; | |||||
NodePtr cast_node_; | |||||
NodePtr cast_node1_; | |||||
NodePtr sqrt_node_; | |||||
NodePtr sqrt_node1_; | |||||
NodePtr input_node_; | |||||
NodePtr input_node1_; | |||||
NodePtr switch_node_t; | |||||
NodePtr switch_node_f; | |||||
NodePtr switch_node1_t; | |||||
NodePtr switch_node1_f; | |||||
NodePtr output_false_node_; | |||||
NodePtr output_true_node_; | |||||
NodePtr merge_node_; | |||||
NodePtr merge_node1_; | |||||
NodePtr relu_node_; | |||||
}; | |||||
TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) { | |||||
ComputeGraphPtr graph = nullptr; | |||||
auto ret = pass_.Run(graph); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
} | |||||
TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) { | |||||
BuildDefaultGraph(); | |||||
auto ret = pass_.Run(graph_); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); | |||||
EXPECT_EQ(true, merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor())); | |||||
EXPECT_EQ(false, output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor())); | |||||
} | |||||
TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) { | |||||
BuildDefaultGraph1(); | |||||
auto ret = pass_.Run(graph_); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); | |||||
} | |||||
TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { | |||||
BuildDefaultGraph2(); | |||||
auto ret = pass_.Run(graph_); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); | |||||
EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); | |||||
} | |||||
TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { | |||||
BuildDefaultGraph1(); | |||||
NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); | |||||
NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true); | |||||
NodePtr add = NewNode("add", ADD, 2, 1, true); | |||||
AttrUtils::SetStr(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
AttrUtils::SetStr(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
sub_graph_->SetParentNode(input_node_); | |||||
sub_graph_->SetParentGraph(graph_); | |||||
auto ret = graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
ret = input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName()); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
ret = input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName()); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
ret = pass_.Run(sub_graph_); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
ret = pass_.Run(graph_); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
} | |||||
} // namespace | |||||
} // namespace ge |