@@ -307,7 +307,7 @@ set(TRAIN_SRC_LIST | |||||
"graph/passes/merge_to_stream_merge_pass.cc" | "graph/passes/merge_to_stream_merge_pass.cc" | ||||
"graph/passes/merge_input_memcpy_pass.cc" | "graph/passes/merge_input_memcpy_pass.cc" | ||||
"graph/passes/switch_to_stream_switch_pass.cc" | "graph/passes/switch_to_stream_switch_pass.cc" | ||||
"graph/passes/mark_branch_force_unknown_pass.cc" | |||||
"graph/passes/mark_force_unknown_for_cond_pass.cc" | |||||
"graph/passes/attach_stream_label_pass.cc" | "graph/passes/attach_stream_label_pass.cc" | ||||
"graph/passes/switch_dead_branch_elimination.cc" | "graph/passes/switch_dead_branch_elimination.cc" | ||||
"graph/passes/replace_transshape_pass.cc" | "graph/passes/replace_transshape_pass.cc" | ||||
@@ -585,7 +585,7 @@ set(INFER_SRC_LIST | |||||
"graph/passes/merge_to_stream_merge_pass.cc" | "graph/passes/merge_to_stream_merge_pass.cc" | ||||
"graph/passes/merge_input_memcpy_pass.cc" | "graph/passes/merge_input_memcpy_pass.cc" | ||||
"graph/passes/switch_to_stream_switch_pass.cc" | "graph/passes/switch_to_stream_switch_pass.cc" | ||||
"graph/passes/mark_branch_force_unknown_pass.cc" | |||||
"graph/passes/mark_force_unknown_for_cond_pass.cc" | |||||
"graph/passes/attach_stream_label_pass.cc" | "graph/passes/attach_stream_label_pass.cc" | ||||
"graph/passes/multi_batch_pass.cc" | "graph/passes/multi_batch_pass.cc" | ||||
"graph/passes/multi_batch_clone_pass.cc" | "graph/passes/multi_batch_clone_pass.cc" | ||||
@@ -65,7 +65,7 @@ | |||||
#include "graph/passes/merge_pass.h" | #include "graph/passes/merge_pass.h" | ||||
#include "graph/passes/merge_input_memcpy_pass.h" | #include "graph/passes/merge_input_memcpy_pass.h" | ||||
#include "graph/passes/merge_to_stream_merge_pass.h" | #include "graph/passes/merge_to_stream_merge_pass.h" | ||||
#include "graph/passes/mark_branch_force_unknown_pass.h" | |||||
#include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||||
#include "graph/passes/multi_batch_pass.h" | #include "graph/passes/multi_batch_pass.h" | ||||
#include "graph/passes/next_iteration_pass.h" | #include "graph/passes/next_iteration_pass.h" | ||||
#include "graph/passes/permute_pass.h" | #include "graph/passes/permute_pass.h" | ||||
@@ -2537,8 +2537,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | ||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | ||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); | ||||
auto mark_force_unknown_pass = new (std::nothrow) MarkBranchForceUnknownPass; | |||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkBranchForceUnknownPass", mark_force_unknown_pass)); | |||||
auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; | |||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass)); | |||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) | ||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) | ||||
GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -14,7 +14,7 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "mark_branch_force_unknown_pass.h" | |||||
#include "mark_force_unknown_for_cond_pass.h" | |||||
#include <queue> | #include <queue> | ||||
@@ -35,8 +35,8 @@ inline bool IsMergeInLoop(const NodePtr &node) { | |||||
} | } | ||||
} | } | ||||
Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { | |||||
GELOGD("MarkBranchForceUnknownPass Enter"); | |||||
Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
GELOGD("MarkForceUnknownForCondPass Enter"); | |||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
std::string node_type; | std::string node_type; | ||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); | GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); | ||||
@@ -58,7 +58,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { | |||||
MarkUnknownForSwitch(node); | MarkUnknownForSwitch(node); | ||||
} | } | ||||
GELOGD("MarkBranchForceUnknownPass Leave"); | |||||
GELOGD("MarkForceUnknownForCondPass Leave"); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -67,7 +67,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { | |||||
/// @param [in] merge node | /// @param [in] merge node | ||||
/// @return | /// @return | ||||
/// | /// | ||||
void MarkBranchForceUnknownPass::MarkUnknownForSwitch(const NodePtr &node) { | |||||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) { | |||||
// Switch --> {Switch --> Merge} --> Merge | // Switch --> {Switch --> Merge} --> Merge | ||||
std::vector<NodePtr> switch_group; | std::vector<NodePtr> switch_group; | ||||
std::unordered_set<NodePtr> nodes_seen; | std::unordered_set<NodePtr> nodes_seen; |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -14,13 +14,13 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ | |||||
#define GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ | |||||
#ifndef GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ | |||||
#define GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ | |||||
#include "inc/graph_pass.h" | #include "inc/graph_pass.h" | ||||
namespace ge { | namespace ge { | ||||
class MarkBranchForceUnknownPass : public GraphPass { | |||||
class MarkForceUnknownForCondPass : public GraphPass { | |||||
public: | public: | ||||
Status Run(ComputeGraphPtr graph); | Status Run(ComputeGraphPtr graph); | ||||
@@ -33,4 +33,4 @@ class MarkBranchForceUnknownPass : public GraphPass { | |||||
void MarkUnknownForSwitch(const NodePtr &node); | void MarkUnknownForSwitch(const NodePtr &node); | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_ | |||||
#endif // GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ |
@@ -239,7 +239,7 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/mark_branch_force_unknown_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/mark_force_unknown_for_cond_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" | ||||
@@ -704,7 +704,7 @@ set(PASS_TEST_FILES | |||||
"graph/passes/net_output_pass_unittest.cc" | "graph/passes/net_output_pass_unittest.cc" | ||||
"graph/passes/no_use_reshape_remove_pass_unittest.cc" | "graph/passes/no_use_reshape_remove_pass_unittest.cc" | ||||
"graph/passes/infershape_pass_unittest.cc" | "graph/passes/infershape_pass_unittest.cc" | ||||
"graph/passes/mark_branch_force_unknown_pass_unittest.cc" | |||||
"graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | |||||
"graph/passes/multi_batch_clone_pass_unittest.cc" | "graph/passes/multi_batch_clone_pass_unittest.cc" | ||||
"graph/passes/replace_with_empty_const_pass_unittest.cc" | "graph/passes/replace_with_empty_const_pass_unittest.cc" | ||||
"graph/passes/link_gen_mask_nodes_pass_unittest.cc" | "graph/passes/link_gen_mask_nodes_pass_unittest.cc" | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/mark_branch_force_unknown_pass.h" | |||||
#include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -29,7 +29,7 @@ | |||||
using namespace std; | using namespace std; | ||||
using namespace testing; | using namespace testing; | ||||
namespace ge { | namespace ge { | ||||
class UtestMarkBranchForceUnknownPass : public testing::Test { | |||||
class UtestMarkForceUnknownForCondPass : public testing::Test { | |||||
protected: | protected: | ||||
void SetUp() {} | void SetUp() {} | ||||
void TearDown() {} | void TearDown() {} | ||||
@@ -194,28 +194,28 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
merge = merge1; | merge = merge1; | ||||
} | } | ||||
TEST_F(UtestMarkBranchForceUnknownPass, skip_while_loop_merge) { | |||||
TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | |||||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | auto graph = std::make_shared<ComputeGraph>("test_graph"); | ||||
NodePtr merge; | NodePtr merge; | ||||
CreateLoopGraph(graph, merge); | CreateLoopGraph(graph, merge); | ||||
AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | ||||
MarkBranchForceUnknownPass mark_force_unknown_pass; | |||||
MarkForceUnknownForCondPass mark_force_unknown_pass; | |||||
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | ||||
} | } | ||||
TEST_F(UtestMarkBranchForceUnknownPass, skip_known_shape_merge) { | |||||
TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | |||||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | auto graph = std::make_shared<ComputeGraph>("test_graph"); | ||||
NodePtr merge; | NodePtr merge; | ||||
CreateCondGraph(graph, merge); | CreateCondGraph(graph, merge); | ||||
MarkBranchForceUnknownPass mark_force_unknown_pass; | |||||
MarkForceUnknownForCondPass mark_force_unknown_pass; | |||||
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip known shape merge | EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip known shape merge | ||||
} | } | ||||
TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) { | |||||
TEST_F(UtestMarkForceUnknownForCondPass, mark_unknown_shape_merge) { | |||||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | auto graph = std::make_shared<ComputeGraph>("test_graph"); | ||||
NodePtr merge; | NodePtr merge; | ||||
CreateCondGraph(graph, merge); | CreateCondGraph(graph, merge); | ||||
@@ -224,7 +224,7 @@ TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) { | |||||
tensor_desc.SetShape(GeShape({-1})); // Set for unknown. | tensor_desc.SetShape(GeShape({-1})); // Set for unknown. | ||||
merge->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | merge->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | ||||
MarkBranchForceUnknownPass mark_force_unknown_pass; | |||||
MarkForceUnknownForCondPass mark_force_unknown_pass; | |||||
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); | EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); | ||||
} | } | ||||
} // namespace ge | } // namespace ge |