diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 404c0928..6ff9f5d9 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -307,7 +307,7 @@ set(TRAIN_SRC_LIST "graph/passes/merge_to_stream_merge_pass.cc" "graph/passes/merge_input_memcpy_pass.cc" "graph/passes/switch_to_stream_switch_pass.cc" - "graph/passes/mark_branch_force_unknown_pass.cc" + "graph/passes/mark_force_unknown_for_cond_pass.cc" "graph/passes/attach_stream_label_pass.cc" "graph/passes/switch_dead_branch_elimination.cc" "graph/passes/replace_transshape_pass.cc" @@ -585,7 +585,7 @@ set(INFER_SRC_LIST "graph/passes/merge_to_stream_merge_pass.cc" "graph/passes/merge_input_memcpy_pass.cc" "graph/passes/switch_to_stream_switch_pass.cc" - "graph/passes/mark_branch_force_unknown_pass.cc" + "graph/passes/mark_force_unknown_for_cond_pass.cc" "graph/passes/attach_stream_label_pass.cc" "graph/passes/multi_batch_pass.cc" "graph/passes/multi_batch_clone_pass.cc" diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index a71d2ab7..819198b0 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -65,7 +65,7 @@ #include "graph/passes/merge_pass.h" #include "graph/passes/merge_input_memcpy_pass.h" #include "graph/passes/merge_to_stream_merge_pass.h" -#include "graph/passes/mark_branch_force_unknown_pass.h" +#include "graph/passes/mark_force_unknown_for_cond_pass.h" #include "graph/passes/multi_batch_pass.h" #include "graph/passes/next_iteration_pass.h" #include "graph/passes/permute_pass.h" @@ -2537,8 +2537,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); - auto mark_force_unknown_pass = new (std::nothrow) MarkBranchForceUnknownPass; - GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkBranchForceUnknownPass", mark_force_unknown_pass)); + auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) GE_CHK_STATUS_RET( diff --git a/ge/graph/passes/mark_branch_force_unknown_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc similarity index 92% rename from ge/graph/passes/mark_branch_force_unknown_pass.cc rename to ge/graph/passes/mark_force_unknown_for_cond_pass.cc index c4c5d1dd..d0b9af7e 100644 --- a/ge/graph/passes/mark_branch_force_unknown_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "mark_branch_force_unknown_pass.h" +#include "mark_force_unknown_for_cond_pass.h" #include @@ -35,8 +35,8 @@ inline bool IsMergeInLoop(const NodePtr &node) { } } -Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { - GELOGD("MarkBranchForceUnknownPass Enter"); +Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { + GELOGD("MarkForceUnknownForCondPass Enter"); for (const auto &node : graph->GetDirectNode()) { std::string node_type; GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); @@ -58,7 +58,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { MarkUnknownForSwitch(node); } - GELOGD("MarkBranchForceUnknownPass Leave"); + GELOGD("MarkForceUnknownForCondPass Leave"); return SUCCESS; } @@ -67,7 +67,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { /// @param [in] merge node /// @return /// -void MarkBranchForceUnknownPass::MarkUnknownForSwitch(const NodePtr &node) { +void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) { // Switch --> {Switch --> Merge} --> Merge std::vector switch_group; std::unordered_set nodes_seen; diff --git a/ge/graph/passes/mark_branch_force_unknown_pass.h b/ge/graph/passes/mark_force_unknown_for_cond_pass.h similarity index 74% rename from ge/graph/passes/mark_branch_force_unknown_pass.h rename to ge/graph/passes/mark_force_unknown_for_cond_pass.h index 4b7f6668..65e09394 100644 --- a/ge/graph/passes/mark_branch_force_unknown_pass.h +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,13 @@ * limitations under the License. */ -#ifndef GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ -#define GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ +#ifndef GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ +#define GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ #include "inc/graph_pass.h" namespace ge { -class MarkBranchForceUnknownPass : public GraphPass { +class MarkForceUnknownForCondPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); @@ -33,4 +33,4 @@ class MarkBranchForceUnknownPass : public GraphPass { void MarkUnknownForSwitch(const NodePtr &node); }; } // namespace ge -#endif // GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ +#endif // GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 55db48e2..9a5806a7 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -239,7 +239,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/mark_branch_force_unknown_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/mark_force_unknown_for_cond_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" @@ -704,7 +704,7 @@ set(PASS_TEST_FILES "graph/passes/net_output_pass_unittest.cc" "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" - "graph/passes/mark_branch_force_unknown_pass_unittest.cc" + "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" "graph/passes/replace_with_empty_const_pass_unittest.cc" "graph/passes/link_gen_mask_nodes_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc b/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc similarity index 94% rename from tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc rename to tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc index 7f1b05ff..b416d958 100644 --- a/tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ #define protected public #define private public -#include "graph/passes/mark_branch_force_unknown_pass.h" +#include "graph/passes/mark_force_unknown_for_cond_pass.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/graph_utils.h" @@ -29,7 +29,7 @@ using namespace std; using namespace testing; namespace ge { -class UtestMarkBranchForceUnknownPass : public testing::Test { +class UtestMarkForceUnknownForCondPass : public testing::Test { protected: void SetUp() {} void TearDown() {} @@ -194,28 +194,28 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { merge = merge1; } -TEST_F(UtestMarkBranchForceUnknownPass, skip_while_loop_merge) { +TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { auto graph = std::make_shared("test_graph"); NodePtr merge; CreateLoopGraph(graph, merge); AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); - MarkBranchForceUnknownPass mark_force_unknown_pass; + MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond } -TEST_F(UtestMarkBranchForceUnknownPass, skip_known_shape_merge) { +TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { auto graph = std::make_shared("test_graph"); NodePtr merge; CreateCondGraph(graph, merge); - MarkBranchForceUnknownPass mark_force_unknown_pass; + MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip known shape merge } -TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) { +TEST_F(UtestMarkForceUnknownForCondPass, mark_unknown_shape_merge) { auto graph = std::make_shared("test_graph"); NodePtr merge; CreateCondGraph(graph, merge); @@ -224,7 +224,7 @@ TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) { tensor_desc.SetShape(GeShape({-1})); // Set for unknown. merge->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); - MarkBranchForceUnknownPass mark_force_unknown_pass; + MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); } } // namespace ge