From e9edaca33f9712acf301d327b22fa29e61c02b79 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Thu, 6 May 2021 09:03:20 +0800 Subject: [PATCH 1/3] Fix mark branch force unknown --- ge/CMakeLists.txt | 2 + ge/graph/manager/graph_manager.cc | 5 +- .../passes/mark_branch_force_unknown_pass.cc | 125 ++++++++++ .../passes/mark_branch_force_unknown_pass.h | 36 +++ tests/ut/ge/CMakeLists.txt | 2 + ...mark_branch_force_unknown_pass_unittest.cc | 230 ++++++++++++++++++ 6 files changed, 399 insertions(+), 1 deletion(-) create mode 100644 ge/graph/passes/mark_branch_force_unknown_pass.cc create mode 100644 ge/graph/passes/mark_branch_force_unknown_pass.h create mode 100644 tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index e25c4892..404c0928 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -307,6 +307,7 @@ set(TRAIN_SRC_LIST "graph/passes/merge_to_stream_merge_pass.cc" "graph/passes/merge_input_memcpy_pass.cc" "graph/passes/switch_to_stream_switch_pass.cc" + "graph/passes/mark_branch_force_unknown_pass.cc" "graph/passes/attach_stream_label_pass.cc" "graph/passes/switch_dead_branch_elimination.cc" "graph/passes/replace_transshape_pass.cc" @@ -584,6 +585,7 @@ set(INFER_SRC_LIST "graph/passes/merge_to_stream_merge_pass.cc" "graph/passes/merge_input_memcpy_pass.cc" "graph/passes/switch_to_stream_switch_pass.cc" + "graph/passes/mark_branch_force_unknown_pass.cc" "graph/passes/attach_stream_label_pass.cc" "graph/passes/multi_batch_pass.cc" "graph/passes/multi_batch_clone_pass.cc" diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 17779161..a71d2ab7 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -65,6 +65,7 @@ #include "graph/passes/merge_pass.h" #include "graph/passes/merge_input_memcpy_pass.h" #include "graph/passes/merge_to_stream_merge_pass.h" +#include "graph/passes/mark_branch_force_unknown_pass.h" #include "graph/passes/multi_batch_pass.h" #include "graph/passes/next_iteration_pass.h" #include "graph/passes/permute_pass.h" @@ -2535,7 +2536,9 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { // the prune pass should between SwitchPass and SwitchToStreamSwitchPass GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); - GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)) + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); + auto mark_force_unknown_pass = new (std::nothrow) MarkBranchForceUnknownPass; + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkBranchForceUnknownPass", mark_force_unknown_pass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) GE_CHK_STATUS_RET( diff --git a/ge/graph/passes/mark_branch_force_unknown_pass.cc b/ge/graph/passes/mark_branch_force_unknown_pass.cc new file mode 100644 index 00000000..4b00b24d --- /dev/null +++ b/ge/graph/passes/mark_branch_force_unknown_pass.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mark_branch_force_unknown_pass.h" + +#include + +#include "graph/common/omg_util.h" + +namespace ge { +namespace { +const std::set kMergeOpTypes{ MERGE, REFMERGE }; + +const std::set kSwitchOpTypes{ SWITCH, REFSWITCH }; + +const std::set kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; + +inline bool IsMergeInLoop(const NodePtr &node) { + std::string node_type; + (void)GetOriginalType(node, node_type); + return kLoopMergeInputs.count(node_type) > 0; +} +} + +Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { + GELOGD("MarkBranchForceUnknownPass Enter"); + for (const auto &node : graph->GetDirectNode()) { + std::string node_type; + GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); + if ((node_type != MERGE) && (node_type != REFMERGE)) { + continue; + } + + const auto op_desc = node->GetOpDesc(); + if (!op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) && !IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) { + GELOGI("Merge[%s] has known shape, no need check switch", node->GetName().c_str()); + continue; + } + + const auto &all_in_nodes = node->GetInDataNodes(); + if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsMergeInLoop)) { + continue; // LoopCond marked in NextIterationPass. + } + + MarkUnknownForSwitch(node); + } + + GELOGD("MarkBranchForceUnknownPass Leave"); + return SUCCESS; +} + +/// +/// @brief Mark force unknown shape for Switch node +/// @param [in] merge node +/// @return +/// +void MarkBranchForceUnknownPass::MarkUnknownForSwitch(const NodePtr &node) { + // Switch --> {Switch --> Merge} --> Merge + std::vector switch_group; + std::unordered_set nodes_seen; + + std::queue> search_queue({{node, 0}}); + while (!search_queue.empty()) { + const auto dst_node = search_queue.front().first; + const auto dst_span = search_queue.front().second; + search_queue.pop(); + + // Switch --> Identity --> Constant + for (const auto &in_node : dst_node->GetInControlNodes()) { + if (nodes_seen.count(in_node) > 0) { + GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); + continue; + } + nodes_seen.insert(in_node); + + if (in_node->GetType() == IDENTITY) { + GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), + in_node->GetName().c_str(), dst_span); + search_queue.push({in_node, dst_span}); + } + } + + for (const auto &in_node : dst_node->GetInDataNodes()) { + if (nodes_seen.count(in_node) > 0) { + GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); + continue; + } + nodes_seen.insert(in_node); + + std::string node_type; + (void)GetOriginalType(in_node, node_type); + GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), + in_node->GetName().c_str(), dst_span); + if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. + if (dst_span > 0) { + search_queue.push({in_node, dst_span - 1}); + } else { + switch_group.emplace_back(in_node); + } + } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. + search_queue.push({in_node, dst_span + 1}); + } else { + search_queue.push({in_node, dst_span}); + } + } + } + + for (const auto &n : switch_group) { + MarkForceUnknownShape(n, true); + } +} +} // namespace ge diff --git a/ge/graph/passes/mark_branch_force_unknown_pass.h b/ge/graph/passes/mark_branch_force_unknown_pass.h new file mode 100644 index 00000000..4b7f6668 --- /dev/null +++ b/ge/graph/passes/mark_branch_force_unknown_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ +#define GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ + +#include "inc/graph_pass.h" + +namespace ge { +class MarkBranchForceUnknownPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + /// + /// @brief Mark force unknown shape for Switch node + /// @param [in] merge node + /// @return + /// + void MarkUnknownForSwitch(const NodePtr &node); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index c3337487..55db48e2 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -239,6 +239,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/mark_branch_force_unknown_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" @@ -703,6 +704,7 @@ set(PASS_TEST_FILES "graph/passes/net_output_pass_unittest.cc" "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" + "graph/passes/mark_branch_force_unknown_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" "graph/passes/replace_with_empty_const_pass_unittest.cc" "graph/passes/link_gen_mask_nodes_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc b/tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc new file mode 100644 index 00000000..7f1b05ff --- /dev/null +++ b/tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc @@ -0,0 +1,230 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define protected public +#define private public +#include "graph/passes/mark_branch_force_unknown_pass.h" + +#include "graph/utils/tensor_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/operator_factory.h" +#include "graph/operator_reg.h" +#include "graph_builder_utils.h" + +using namespace std; +using namespace testing; +namespace ge { +class UtestMarkBranchForceUnknownPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { + OpDescPtr op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + TensorUtils::SetSize(tensor, 512); + vector input_offset; + for (int i = 0; i < in_num; i++) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(1024); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; i++) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(1024); + } + op_desc->SetOutputOffset(output_offset); + + op_desc->SetWorkspace({}); + op_desc->SetWorkspaceBytes({}); + op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); + + const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; + op_desc->AddInferFunc(stub_func); + op_desc->AddInferFormatFunc(stub_func); + op_desc->AddVerifierFunc(stub_func); + + return graph.AddNode(op_desc); +} + +static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) { +/******************************************************************************* + * Exit Identify + * \ / \. + * \ / \. + * Switch Add + * / | | + * / | | + * / | | + * LoopCond | | + * \ | | + * \ | | + * \ | | + * Less | | + * \ | NextIteration + * \ | | + * \ | | + * Merge <---------| + * | + * | + * Enter + ******************************************************************************/ + auto data1 = CreateNode(*graph, "data", DATA, 1, 1); + auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); + auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); + auto less1 = CreateNode(*graph, "less", LESS, 2, 1); + auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); + auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); + auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); + auto add1 = CreateNode(*graph, "add", ADD, 2, 1); + auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); + auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); + auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); + GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); + GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); + + GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); + GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); + GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); + GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); + + merge = merge1; +} + +static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { +/******************************************************************************* + * NetOutput + * | + * | + * Merge + * / \. + * / \. + * / \. + * Add Sub + * | \ | \. + * | \ | \. + * | \ | Const + * | \ | \. + * | \ | Identify + * | \ | | + * Switch Switch Switch Switch + * | | | | | + * | | | | | + * x y Cond z + ******************************************************************************/ + auto data1 = CreateNode(*graph, "data_x", DATA, 1, 1); + auto data2 = CreateNode(*graph, "data_y", DATA, 1, 1); + auto data3 = CreateNode(*graph, "data_z", DATA, 1, 1); + + auto less1 = CreateNode(*graph, "less", LESS, 2, 1); + + auto switch1 = CreateNode(*graph, "switch_x", SWITCH, 2, 2); + auto switch2 = CreateNode(*graph, "switch_y", SWITCH, 2, 2); + auto switch3 = CreateNode(*graph, "switch_z", SWITCH, 2, 2); + auto switch4 = CreateNode(*graph, "switch_i", SWITCH, 2, 2); + + auto add1 = CreateNode(*graph, "add", ADD, 2, 1); + auto sub1 = CreateNode(*graph, "add", SUB, 2, 1); + auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); + auto const1 = CreateNode(*graph, "const", CONSTANT, 0, 1); + + auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); + auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); + GraphUtils::AddEdge(data3->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch4->GetInDataAnchor(0)); + + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch4->GetInDataAnchor(1)); + + GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch2->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); + GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), sub1->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch4->GetOutDataAnchor(0), ident1->GetInDataAnchor(1)); + GraphUtils::AddEdge(ident1->GetOutControlAnchor(), const1->GetInControlAnchor()); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), sub1->GetInDataAnchor(1)); + + GraphUtils::AddEdge(add1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); + GraphUtils::AddEdge(sub1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); + + merge = merge1; +} + +TEST_F(UtestMarkBranchForceUnknownPass, skip_while_loop_merge) { + auto graph = std::make_shared("test_graph"); + NodePtr merge; + CreateLoopGraph(graph, merge); + + AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); + + MarkBranchForceUnknownPass mark_force_unknown_pass; + EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond +} + +TEST_F(UtestMarkBranchForceUnknownPass, skip_known_shape_merge) { + auto graph = std::make_shared("test_graph"); + NodePtr merge; + CreateCondGraph(graph, merge); + + MarkBranchForceUnknownPass mark_force_unknown_pass; + EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip known shape merge +} + + +TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) { + auto graph = std::make_shared("test_graph"); + NodePtr merge; + CreateCondGraph(graph, merge); + + auto tensor_desc = merge->GetOpDesc()->GetOutputDesc(0); + tensor_desc.SetShape(GeShape({-1})); // Set for unknown. + merge->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + + MarkBranchForceUnknownPass mark_force_unknown_pass; + EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); +} +} // namespace ge From 1001a8a8598ed9b9aa9cf4557d6c6d74a893e6a2 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Thu, 6 May 2021 09:45:41 +0800 Subject: [PATCH 2/3] revert MergeInputMemcpyPass --- .../passes/mark_branch_force_unknown_pass.cc | 2 +- ge/graph/passes/merge_input_memcpy_pass.cc | 99 ------------------- ge/graph/passes/merge_input_memcpy_pass.h | 15 --- 3 files changed, 1 insertion(+), 115 deletions(-) diff --git a/ge/graph/passes/mark_branch_force_unknown_pass.cc b/ge/graph/passes/mark_branch_force_unknown_pass.cc index 4b00b24d..c4c5d1dd 100644 --- a/ge/graph/passes/mark_branch_force_unknown_pass.cc +++ b/ge/graph/passes/mark_branch_force_unknown_pass.cc @@ -40,7 +40,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { std::string node_type; GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); - if ((node_type != MERGE) && (node_type != REFMERGE)) { + if (kMergeOpTypes.count(node_type) == 0) { continue; } diff --git a/ge/graph/passes/merge_input_memcpy_pass.cc b/ge/graph/passes/merge_input_memcpy_pass.cc index ce38a3dd..c4273584 100644 --- a/ge/graph/passes/merge_input_memcpy_pass.cc +++ b/ge/graph/passes/merge_input_memcpy_pass.cc @@ -16,18 +16,11 @@ #include "graph/passes/merge_input_memcpy_pass.h" -#include - #include "common/ge/ge_util.h" #include "ge/ge_api_types.h" #include "graph/common/omg_util.h" namespace ge { -namespace { -const std::set kLoopMergeInputs{ - ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION -}; -} Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { GELOGD("MergeInputMemcpyPass Enter"); std::unordered_map> switch_groups; @@ -41,10 +34,8 @@ Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(node->GetOpDesc()); GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)), "Merge add memcpy node failed."); - CollectSwitchGroup(node, switch_groups); } - MarkUnknownForSwitch(switch_groups); GELOGD("MergeInputMemcpyPass Leave"); return SUCCESS; } @@ -114,94 +105,4 @@ NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph return graph->AddNode(op_desc); } - -/// -/// @brief Mark force unknown shape for Switch node -/// @param [in] merge node -/// @param [out] switch_groups -/// @return -/// -void MergeInputMemcpyPass::CollectSwitchGroup(const NodePtr &node, - std::unordered_map> &switch_groups) { - const auto &op_desc = node->GetOpDesc(); - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - const auto &src_out_anchor = in_anchor->GetPeerOutAnchor(); - if (src_out_anchor == nullptr) { - continue; - } - - std::string node_type; - GetOriginalType(src_out_anchor->GetOwnerNode(), node_type); - if (kLoopMergeInputs.count(node_type) > 0) { - return; - } - } - - // Switch --> {Switch --> Merge} --> Merge - std::queue> search_queue; - search_queue.push({node, 0}); - std::vector &switch_group = switch_groups[node]; - while (!search_queue.empty()) { - const auto dst_node = search_queue.front().first; - const auto dst_span = search_queue.front().second; - search_queue.pop(); - - // Switch --> Identity --> Constant - for (const auto &in_ctrl_node : dst_node->GetInControlNodes()) { - if (in_ctrl_node->GetType() == IDENTITY) { - GELOGD("Travel node: %s, In control: %s, span is: %u", - dst_node->GetName().c_str(), in_ctrl_node->GetName().c_str(), dst_span); - search_queue.push({in_ctrl_node, dst_span}); - } - } - - for (const auto &in_data_node : dst_node->GetInDataNodes()) { - std::string node_type; - GetOriginalType(in_data_node, node_type); - GELOGD("Travel node: %s, %s node: %s, span is: %u", - dst_node->GetName().c_str(), node_type.c_str(), in_data_node->GetName().c_str(), dst_span); - if (node_type == SWITCH || node_type == REFSWITCH) { - if (dst_span > 0) { - search_queue.push({in_data_node, dst_span - 1}); - } else { - switch_group.emplace_back(in_data_node); - } - } else if (node_type == MERGE || node_type == REFMERGE) { - search_queue.push({in_data_node, dst_span + 1}); - } else { - search_queue.push({in_data_node, dst_span}); - } - } - } - - if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) { - GELOGI("Mark [%s] as for unknown shape, switch groups: %zu", node->GetName().c_str(), switch_groups.size()); - MarkForceUnknownShape(node, true); - for (const auto &n : switch_group) { - MarkForceUnknownShape(n, true); - } - } -} - -void MergeInputMemcpyPass::MarkUnknownForSwitch(const std::unordered_map> &switch_groups) { - std::function callback = [](const NodePtr &n) { - return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); - }; - - for (const auto &item : switch_groups) { - const auto &node = item.first; - if (node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) { - continue; - } - - const std::vector &switch_group = item.second; - if (std::any_of(switch_group.begin(), switch_group.end(), callback)) { - GELOGI("Mark [%s] as force unknown shape, switch nodes: %zu", node->GetName().c_str(), switch_group.size()); - MarkForceUnknownShape(node, true); - for (const auto &n : switch_group) { - MarkForceUnknownShape(n, true); - } - } - } -} } // namespace ge diff --git a/ge/graph/passes/merge_input_memcpy_pass.h b/ge/graph/passes/merge_input_memcpy_pass.h index 2c7636ea..b8c6f0b8 100644 --- a/ge/graph/passes/merge_input_memcpy_pass.h +++ b/ge/graph/passes/merge_input_memcpy_pass.h @@ -44,21 +44,6 @@ class MergeInputMemcpyPass : public GraphPass { /// NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); - - /// - /// @brief Mark force unknown shape for Switch node - /// @param [in] merge node - /// @param [out] switch_groups - /// @return - /// - void CollectSwitchGroup(const NodePtr &node, std::unordered_map> &switch_groups); - - /// - /// @brief Mark force unknown shape for Switch node - /// @param [in] switch_groups - /// @return - /// - void MarkUnknownForSwitch(const std::unordered_map> &switch_groups); }; } // namespace ge #endif // GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_ From a4aae9c691b21c9f6b7dae5eb6dd261022dd59f7 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Thu, 6 May 2021 10:39:57 +0800 Subject: [PATCH 3/3] Rename mark_branch_force_unknown_pass --- ge/CMakeLists.txt | 4 ++-- ge/graph/manager/graph_manager.cc | 6 +++--- ....cc => mark_force_unknown_for_cond_pass.cc} | 12 ++++++------ ...ss.h => mark_force_unknown_for_cond_pass.h} | 10 +++++----- tests/ut/ge/CMakeLists.txt | 4 ++-- ...rk_force_unknown_for_cond_pass_unittest.cc} | 18 +++++++++--------- 6 files changed, 27 insertions(+), 27 deletions(-) rename ge/graph/passes/{mark_branch_force_unknown_pass.cc => mark_force_unknown_for_cond_pass.cc} (92%) rename ge/graph/passes/{mark_branch_force_unknown_pass.h => mark_force_unknown_for_cond_pass.h} (74%) rename tests/ut/ge/graph/passes/{mark_branch_force_unknown_pass_unittest.cc => mark_force_unknown_for_cond_pass_unittest.cc} (94%) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 404c0928..6ff9f5d9 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -307,7 +307,7 @@ set(TRAIN_SRC_LIST "graph/passes/merge_to_stream_merge_pass.cc" "graph/passes/merge_input_memcpy_pass.cc" "graph/passes/switch_to_stream_switch_pass.cc" - "graph/passes/mark_branch_force_unknown_pass.cc" + "graph/passes/mark_force_unknown_for_cond_pass.cc" "graph/passes/attach_stream_label_pass.cc" "graph/passes/switch_dead_branch_elimination.cc" "graph/passes/replace_transshape_pass.cc" @@ -585,7 +585,7 @@ set(INFER_SRC_LIST "graph/passes/merge_to_stream_merge_pass.cc" "graph/passes/merge_input_memcpy_pass.cc" "graph/passes/switch_to_stream_switch_pass.cc" - "graph/passes/mark_branch_force_unknown_pass.cc" + "graph/passes/mark_force_unknown_for_cond_pass.cc" "graph/passes/attach_stream_label_pass.cc" "graph/passes/multi_batch_pass.cc" "graph/passes/multi_batch_clone_pass.cc" diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index a71d2ab7..819198b0 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -65,7 +65,7 @@ #include "graph/passes/merge_pass.h" #include "graph/passes/merge_input_memcpy_pass.h" #include "graph/passes/merge_to_stream_merge_pass.h" -#include "graph/passes/mark_branch_force_unknown_pass.h" +#include "graph/passes/mark_force_unknown_for_cond_pass.h" #include "graph/passes/multi_batch_pass.h" #include "graph/passes/next_iteration_pass.h" #include "graph/passes/permute_pass.h" @@ -2537,8 +2537,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); - auto mark_force_unknown_pass = new (std::nothrow) MarkBranchForceUnknownPass; - GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkBranchForceUnknownPass", mark_force_unknown_pass)); + auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) GE_CHK_STATUS_RET( diff --git a/ge/graph/passes/mark_branch_force_unknown_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc similarity index 92% rename from ge/graph/passes/mark_branch_force_unknown_pass.cc rename to ge/graph/passes/mark_force_unknown_for_cond_pass.cc index c4c5d1dd..d0b9af7e 100644 --- a/ge/graph/passes/mark_branch_force_unknown_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "mark_branch_force_unknown_pass.h" +#include "mark_force_unknown_for_cond_pass.h" #include @@ -35,8 +35,8 @@ inline bool IsMergeInLoop(const NodePtr &node) { } } -Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { - GELOGD("MarkBranchForceUnknownPass Enter"); +Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { + GELOGD("MarkForceUnknownForCondPass Enter"); for (const auto &node : graph->GetDirectNode()) { std::string node_type; GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); @@ -58,7 +58,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { MarkUnknownForSwitch(node); } - GELOGD("MarkBranchForceUnknownPass Leave"); + GELOGD("MarkForceUnknownForCondPass Leave"); return SUCCESS; } @@ -67,7 +67,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { /// @param [in] merge node /// @return /// -void MarkBranchForceUnknownPass::MarkUnknownForSwitch(const NodePtr &node) { +void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) { // Switch --> {Switch --> Merge} --> Merge std::vector switch_group; std::unordered_set nodes_seen; diff --git a/ge/graph/passes/mark_branch_force_unknown_pass.h b/ge/graph/passes/mark_force_unknown_for_cond_pass.h similarity index 74% rename from ge/graph/passes/mark_branch_force_unknown_pass.h rename to ge/graph/passes/mark_force_unknown_for_cond_pass.h index 4b7f6668..65e09394 100644 --- a/ge/graph/passes/mark_branch_force_unknown_pass.h +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,13 @@ * limitations under the License. */ -#ifndef GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ -#define GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ +#ifndef GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ +#define GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ #include "inc/graph_pass.h" namespace ge { -class MarkBranchForceUnknownPass : public GraphPass { +class MarkForceUnknownForCondPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); @@ -33,4 +33,4 @@ class MarkBranchForceUnknownPass : public GraphPass { void MarkUnknownForSwitch(const NodePtr &node); }; } // namespace ge -#endif // GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ +#endif // GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 55db48e2..9a5806a7 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -239,7 +239,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/mark_branch_force_unknown_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/mark_force_unknown_for_cond_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" @@ -704,7 +704,7 @@ set(PASS_TEST_FILES "graph/passes/net_output_pass_unittest.cc" "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" - "graph/passes/mark_branch_force_unknown_pass_unittest.cc" + "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" "graph/passes/replace_with_empty_const_pass_unittest.cc" "graph/passes/link_gen_mask_nodes_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc b/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc similarity index 94% rename from tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc rename to tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc index 7f1b05ff..b416d958 100644 --- a/tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ #define protected public #define private public -#include "graph/passes/mark_branch_force_unknown_pass.h" +#include "graph/passes/mark_force_unknown_for_cond_pass.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/graph_utils.h" @@ -29,7 +29,7 @@ using namespace std; using namespace testing; namespace ge { -class UtestMarkBranchForceUnknownPass : public testing::Test { +class UtestMarkForceUnknownForCondPass : public testing::Test { protected: void SetUp() {} void TearDown() {} @@ -194,28 +194,28 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { merge = merge1; } -TEST_F(UtestMarkBranchForceUnknownPass, skip_while_loop_merge) { +TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { auto graph = std::make_shared("test_graph"); NodePtr merge; CreateLoopGraph(graph, merge); AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); - MarkBranchForceUnknownPass mark_force_unknown_pass; + MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond } -TEST_F(UtestMarkBranchForceUnknownPass, skip_known_shape_merge) { +TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { auto graph = std::make_shared("test_graph"); NodePtr merge; CreateCondGraph(graph, merge); - MarkBranchForceUnknownPass mark_force_unknown_pass; + MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip known shape merge } -TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) { +TEST_F(UtestMarkForceUnknownForCondPass, mark_unknown_shape_merge) { auto graph = std::make_shared("test_graph"); NodePtr merge; CreateCondGraph(graph, merge); @@ -224,7 +224,7 @@ TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) { tensor_desc.SetShape(GeShape({-1})); // Set for unknown. merge->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); - MarkBranchForceUnknownPass mark_force_unknown_pass; + MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); } } // namespace ge