Browse Source

Rename mark_branch_force_unknown_pass

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
a4aae9c691
6 changed files with 27 additions and 27 deletions
  1. +2
    -2
      ge/CMakeLists.txt
  2. +3
    -3
      ge/graph/manager/graph_manager.cc
  3. +6
    -6
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  4. +5
    -5
      ge/graph/passes/mark_force_unknown_for_cond_pass.h
  5. +2
    -2
      tests/ut/ge/CMakeLists.txt
  6. +9
    -9
      tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc

+ 2
- 2
ge/CMakeLists.txt View File

@@ -307,7 +307,7 @@ set(TRAIN_SRC_LIST
"graph/passes/merge_to_stream_merge_pass.cc"
"graph/passes/merge_input_memcpy_pass.cc"
"graph/passes/switch_to_stream_switch_pass.cc"
"graph/passes/mark_branch_force_unknown_pass.cc"
"graph/passes/mark_force_unknown_for_cond_pass.cc"
"graph/passes/attach_stream_label_pass.cc"
"graph/passes/switch_dead_branch_elimination.cc"
"graph/passes/replace_transshape_pass.cc"
@@ -585,7 +585,7 @@ set(INFER_SRC_LIST
"graph/passes/merge_to_stream_merge_pass.cc"
"graph/passes/merge_input_memcpy_pass.cc"
"graph/passes/switch_to_stream_switch_pass.cc"
"graph/passes/mark_branch_force_unknown_pass.cc"
"graph/passes/mark_force_unknown_for_cond_pass.cc"
"graph/passes/attach_stream_label_pass.cc"
"graph/passes/multi_batch_pass.cc"
"graph/passes/multi_batch_clone_pass.cc"


+ 3
- 3
ge/graph/manager/graph_manager.cc View File

@@ -65,7 +65,7 @@
#include "graph/passes/merge_pass.h"
#include "graph/passes/merge_input_memcpy_pass.h"
#include "graph/passes/merge_to_stream_merge_pass.h"
#include "graph/passes/mark_branch_force_unknown_pass.h"
#include "graph/passes/mark_force_unknown_for_cond_pass.h"
#include "graph/passes/multi_batch_pass.h"
#include "graph/passes/next_iteration_pass.h"
#include "graph/passes/permute_pass.h"
@@ -2537,8 +2537,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass));
auto mark_force_unknown_pass = new (std::nothrow) MarkBranchForceUnknownPass;
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkBranchForceUnknownPass", mark_force_unknown_pass));
auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass;
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass))
GE_CHK_STATUS_RET(


ge/graph/passes/mark_branch_force_unknown_pass.cc → ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "mark_branch_force_unknown_pass.h"
#include "mark_force_unknown_for_cond_pass.h"

#include <queue>

@@ -35,8 +35,8 @@ inline bool IsMergeInLoop(const NodePtr &node) {
}
}

Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) {
GELOGD("MarkBranchForceUnknownPass Enter");
Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
GELOGD("MarkForceUnknownForCondPass Enter");
for (const auto &node : graph->GetDirectNode()) {
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed.");
@@ -58,7 +58,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) {
MarkUnknownForSwitch(node);
}

GELOGD("MarkBranchForceUnknownPass Leave");
GELOGD("MarkForceUnknownForCondPass Leave");
return SUCCESS;
}

@@ -67,7 +67,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) {
/// @param [in] merge node
/// @return
///
void MarkBranchForceUnknownPass::MarkUnknownForSwitch(const NodePtr &node) {
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) {
// Switch --> {Switch --> Merge} --> Merge
std::vector<NodePtr> switch_group;
std::unordered_set<NodePtr> nodes_seen;

ge/graph/passes/mark_branch_force_unknown_pass.h → ge/graph/passes/mark_force_unknown_for_cond_pass.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,13 +14,13 @@
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_
#define GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_
#ifndef GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_
#define GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_

#include "inc/graph_pass.h"

namespace ge {
class MarkBranchForceUnknownPass : public GraphPass {
class MarkForceUnknownForCondPass : public GraphPass {
public:
Status Run(ComputeGraphPtr graph);

@@ -33,4 +33,4 @@ class MarkBranchForceUnknownPass : public GraphPass {
void MarkUnknownForSwitch(const NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_MARK_BRANCH_FORCE_UNKNOWN_PASS_H_
#endif // GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_

+ 2
- 2
tests/ut/ge/CMakeLists.txt View File

@@ -239,7 +239,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/mark_branch_force_unknown_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/mark_force_unknown_for_cond_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc"
@@ -704,7 +704,7 @@ set(PASS_TEST_FILES
"graph/passes/net_output_pass_unittest.cc"
"graph/passes/no_use_reshape_remove_pass_unittest.cc"
"graph/passes/infershape_pass_unittest.cc"
"graph/passes/mark_branch_force_unknown_pass_unittest.cc"
"graph/passes/mark_force_unknown_for_cond_pass_unittest.cc"
"graph/passes/multi_batch_clone_pass_unittest.cc"
"graph/passes/replace_with_empty_const_pass_unittest.cc"
"graph/passes/link_gen_mask_nodes_pass_unittest.cc"


tests/ut/ge/graph/passes/mark_branch_force_unknown_pass_unittest.cc → tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@

#define protected public
#define private public
#include "graph/passes/mark_branch_force_unknown_pass.h"
#include "graph/passes/mark_force_unknown_for_cond_pass.h"

#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"
@@ -29,7 +29,7 @@
using namespace std;
using namespace testing;
namespace ge {
class UtestMarkBranchForceUnknownPass : public testing::Test {
class UtestMarkForceUnknownForCondPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
@@ -194,28 +194,28 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) {
merge = merge1;
}

TEST_F(UtestMarkBranchForceUnknownPass, skip_while_loop_merge) {
TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) {
auto graph = std::make_shared<ComputeGraph>("test_graph");
NodePtr merge;
CreateLoopGraph(graph, merge);

AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);

MarkBranchForceUnknownPass mark_force_unknown_pass;
MarkForceUnknownForCondPass mark_force_unknown_pass;
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond
}

TEST_F(UtestMarkBranchForceUnknownPass, skip_known_shape_merge) {
TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) {
auto graph = std::make_shared<ComputeGraph>("test_graph");
NodePtr merge;
CreateCondGraph(graph, merge);

MarkBranchForceUnknownPass mark_force_unknown_pass;
MarkForceUnknownForCondPass mark_force_unknown_pass;
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip known shape merge
}


TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) {
TEST_F(UtestMarkForceUnknownForCondPass, mark_unknown_shape_merge) {
auto graph = std::make_shared<ComputeGraph>("test_graph");
NodePtr merge;
CreateCondGraph(graph, merge);
@@ -224,7 +224,7 @@ TEST_F(UtestMarkBranchForceUnknownPass, mark_unknown_shape_merge) {
tensor_desc.SetShape(GeShape({-1})); // Set for unknown.
merge->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);

MarkBranchForceUnknownPass mark_force_unknown_pass;
MarkForceUnknownForCondPass mark_force_unknown_pass;
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save