Browse Source

UT for control flow group

tags/v1.5.1
zhangxiaokun 3 years ago
parent
commit
572990a616
23 changed files with 195 additions and 139 deletions
  1. +4
    -0
      tests/depends/mmpa/src/mmpa_stub.cc
  2. +1
    -2
      tests/ut/ge/CMakeLists.txt
  3. +5
    -5
      tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc
  4. +1
    -1
      tests/ut/ge/graph/build/stream_allocator_unittest.cc
  5. +3
    -3
      tests/ut/ge/graph/passes/assert_pass_unittest.cc
  6. +7
    -7
      tests/ut/ge/graph/passes/base_pass_unittest.cc
  7. +3
    -3
      tests/ut/ge/graph/passes/cond_branch_v1_unittest.cc
  8. +19
    -19
      tests/ut/ge/graph/passes/constant_folding_pass_unittest.cc
  9. +4
    -4
      tests/ut/ge/graph/passes/dimension_compute_pass_unittest.cc
  10. +1
    -1
      tests/ut/ge/graph/passes/folding_kernel/ssd_prior_box_kernel_unittest.cc
  11. +1
    -1
      tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc
  12. +91
    -38
      tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc
  13. +14
    -14
      tests/ut/ge/graph/passes/merge_pass_unittest.cc
  14. +6
    -6
      tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc
  15. +3
    -3
      tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc
  16. +8
    -8
      tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc
  17. +1
    -1
      tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc
  18. +6
    -6
      tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc
  19. +2
    -2
      tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc
  20. +7
    -7
      tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc
  21. +2
    -2
      tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc
  22. +1
    -1
      tests/ut/ge/graph/passes/variable_op_pass_unittest.cc
  23. +5
    -5
      tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc

+ 4
- 0
tests/depends/mmpa/src/mmpa_stub.cc View File

@@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName)

INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len)
{
const char *env = getenv(name);
if (env != nullptr) {
strcpy(value, env);
}
return 0;
}



+ 1
- 2
tests/ut/ge/CMakeLists.txt View File

@@ -720,7 +720,6 @@ set(PASS_TEST_FILES
"graph/passes/memcpy_addr_async_unittest.cc"
"graph/passes/hccl_continuous_pass_unittest.cc"
"graph/passes/hccl_memcpy_pass_unittest.cc"
)

set(KERNEL_TEST_FILES
@@ -850,7 +849,6 @@ set(HYBRID_TEST_FILES
"hybrid/executor/hybrid_model_async_executor_unittest.cc"
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc"
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc"

)

set(OTHERS_TEST_FILES
@@ -877,6 +875,7 @@ add_library(ge_ut_graph STATIC

target_compile_definitions(ge_ut_graph PRIVATE
google=ascend_private
FMK_SUPPORT_DUMP
)

target_compile_options(ge_ut_graph PRIVATE


+ 5
- 5
tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc View File

@@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test {
/// B --> C(AllReduce) --- D
/// /
/// stream id: 0 A
/// \
/// \.
/// E --> F(AllReduce) --- G
/// stream id: 2 2 2
///
@@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) {

/// case of multi-output, then unuse stream
/// sub1
/// / | \
/// / | \.
/// sub2 sub3 sub4
TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) {
SubGraphInfoPtr data = CreateDataSubgraph();
@@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) {

/// if paralle id 1, then use stream
/// sub1
/// / | | \
/// / | | \.
/// sub2 sub3 sub4 sub5
TEST_F(UtestLogicalStreamAllocator, test_parallel_one) {
SubGraphInfoPtr data = CreateDataSubgraph();
@@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) {

/// if the param of engine independent is true, then set independent stream
/// sub1
/// / | | \
/// / | | \.
/// sub2 sub3 sub4 sub5
TEST_F(UtestLogicalStreamAllocator, test_independent) {
SubGraphInfoPtr data = CreateDataSubgraph();
@@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) {

/// set stream based on stream label, and then based on independent
/// sub1
/// / | | \
/// / | | \.
/// sub2 sub3 sub4 sub5
TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) {
SubGraphInfoPtr data = CreateDataSubgraph();


+ 1
- 1
tests/ut/ge/graph/build/stream_allocator_unittest.cc View File

@@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test {

///
/// A
/// / \
/// / \.
/// B C
/// | |
/// D 400


+ 3
- 3
tests/ut/ge/graph/passes/assert_pass_unittest.cc View File

@@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test {
};

/// D E
/// | \ | \
/// | \ | \.
/// F C G
/// : | :
/// H A I
@@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) {
EXPECT_EQ(graph->FindNode("D"), nullptr);
}

/// E F
/// | \ | \
/// E F
/// | \ | \.
/// H C -> D G
/// \ | :
/// A I


+ 7
- 7
tests/ut/ge/graph/passes/base_pass_unittest.cc View File

@@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test {
/// reshape1
/// |
/// add1
/// / \
/// / \.
/// | |
/// data1 const1
ComputeGraphPtr BuildGraph1() {
@@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() {
}

/// sum1
/// / \
/// / \
/// / \
/// / \.
/// / \.
/// / \.
/// reshape1 addn1
/// | c |
/// add1 <--- shape1
@@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::str
/// Op1
/// |
/// Merge
/// / \
/// / \.
/// Op2 Op3
TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) {
auto builder = ut::GraphBuilder("g1");
@@ -245,7 +245,7 @@ TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) {
/// Op1
/// |
/// Merge
/// / \
/// / \.
/// Op2 Op3
TEST_F(UTESTGraphPassesBasePass, del_isolate_success) {
auto builder = ut::GraphBuilder("g1");
@@ -459,7 +459,7 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) {
/// data1 const
/// \ /
/// while
/// / \
/// / \.
/// | |
/// cast1 cast2
ComputeGraphPtr BuildWhileGraph1() {


+ 3
- 3
tests/ut/ge/graph/passes/cond_branch_v1_unittest.cc View File

@@ -34,11 +34,11 @@ namespace {
/// net_output
/// |
/// merge
/// / \
/// / \.
/// square add
/// F| T/ T\
/// F| T/ T\.
/// switch1 switch2
/// / \ / \
/// / \ / \.
/// var1 var2 var3
///
ComputeGraphPtr BuildGraph1() {


+ 19
- 19
tests/ut/ge/graph/passes/constant_folding_pass_unittest.cc View File

@@ -173,8 +173,8 @@ namespace {
/// shapeNo1
/// |
/// addnYes1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2
ComputeGraphPtr BuildGraph1() {
auto builder = ut::GraphBuilder("test");
@@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() {
/// shapeNo1
/// | c
/// addnYes1 <----- dataNo1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2
ComputeGraphPtr BuildGraph3() {
auto builder = ut::GraphBuilder("test");
@@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() {
/// shapeNo1
/// | c
/// addnYes1 <---------
/// / \ \
/// / \ c \
/// / \ \.
/// / \ c \.
/// const1 const2 <----- dataNo1
ComputeGraphPtr BuildGraph4() {
auto builder = ut::GraphBuilder("test");
@@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() {
/// shapeNo1
/// | c
/// addnYes1 <----- dataNo1
/// / \
/// / \.
/// / \ c
/// const1 const2 <----- dataNo2
ComputeGraphPtr BuildGraph5() {
@@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() {
/// addYes1 <---- const3
/// |
/// addnYes1 <-
/// / \ \
/// / \ \
/// / \ \.
/// / \ \.
/// const1 const2 const4
ComputeGraphPtr BuildGraph6() {
auto builder = ut::GraphBuilder("test");
@@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() {
}

/// netoutput1
/// / \
/// / \.
/// shapeNo1 ShpaeNo2
/// \ /
/// huberLoss1
/// / | \
/// / | \
/// / | \.
/// / | \.
/// const1 const2 const3
ComputeGraphPtr BuildGraph7() {
auto builder = ut::GraphBuilder("test");
@@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() {
/// shapeNo1
/// |
/// addnNo1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2
ComputeGraphPtr BuildGraph8() {
auto builder = ut::GraphBuilder("test");
@@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() {
/// shapeNo1
/// |
/// addnYes1
/// / \
/// / \
/// / \.
/// / \.
/// const1 data1
ComputeGraphPtr BuildGraph9() {
auto builder = ut::GraphBuilder("test");
@@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() {
}

/// netoutput1
/// / \
/// / \.
/// addDim sqrt1
/// \ /
/// switch1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2
ComputeGraphPtr BuildGraph10() {
auto builder = ut::GraphBuilder("test");


+ 4
- 4
tests/ut/ge/graph/passes/dimension_compute_pass_unittest.cc View File

@@ -63,8 +63,8 @@ namespace {
/// shapeNo1
/// |
/// addnNo1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2
ComputeGraphPtr BuildGraph8() {
auto builder = ut::GraphBuilder("test");
@@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() {
/// shapeNo1
/// |
/// addnYes1
/// / \
/// / \
/// / \.
/// / \.
///const1 data1
ComputeGraphPtr BuildGraph9() {
auto builder = ut::GraphBuilder("test");


+ 1
- 1
tests/ut/ge/graph/passes/folding_kernel/ssd_prior_box_kernel_unittest.cc View File

@@ -46,7 +46,7 @@ class UtestGraphPassesFoldingKernelSsdPriorboxKernel : public testing::Test {
/// convolution data
/// | /
/// ssdpriorbox
/// \
/// \.
/// reshape
class NodeBuilder {
public:


+ 1
- 1
tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc View File

@@ -120,7 +120,7 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {

/// graph with subgraph
/// const
/// / \
/// / \.
/// cast1 cast1
/// \ /
/// case


+ 91
- 38
tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc View File

@@ -69,62 +69,100 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string
return graph.AddNode(op_desc);
}

static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) {
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge, vector<NodePtr> &loop, vector<NodePtr> &cond) {
/*******************************************************************************
* Exit Identify
* \ / \.
* \ / \.
* Switch Add
* / | |
* / | |
* / | |
* LoopCond | |
* \ | |
* \ | |
* \ | |
* Less | |
* \ | NextIteration
* \ | |
* \ | |
* Merge <---------|
* |
* |
* Enter
* |
* +--------------------- Merge ----------------------+
* / |
* / |
* / |
* / |
* Exit Identify |
* \ / \. |
* \ / \. |
* Switch Add Add
* / | | |
* / | | |
* / | | |
* LoopCond | | |
* \ | | |
* \ | | |
* \ | | |
* Less | | |
* \ | NextIteration |
* \ | | |
* \ | | |
* Merge <---------| |
* | |
* | |
* Enter |
* \ |
* \ |
* Switch Switch
* | |
* +-----------------Equal----------------------+
* |
******************************************************************************/
auto data1 = CreateNode(*graph, "data", DATA, 1, 1);
auto data1 = CreateNode(*graph, "data1", DATA, 1, 1);
auto data2 = CreateNode(*graph, "data2", DATA, 1, 1);

auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1);
auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2);
auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2);

auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1);
auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2);
auto less1 = CreateNode(*graph, "less", LESS, 2, 1);
auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2);
auto less1 = CreateNode(*graph, "less1", LESS, 2, 1);
auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1);
auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2);
auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2);
auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1);
auto add1 = CreateNode(*graph, "add", ADD, 2, 1);
auto add1 = CreateNode(*graph, "add1", ADD, 2, 1);
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1);
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1);

auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1);
auto add2 = CreateNode(*graph, "add2", ADD, 2, 1);
auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2);
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);

GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0));
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0));
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1));
cond.emplace_back(switch1);
cond.emplace_back(switch2);

GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0));

GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1));
loop.emplace_back(merge1);

GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0));
GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0));
GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false
GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true
loop.emplace_back(switch3);

GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0));

GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));

merge = merge1;
GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true
GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0));

GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0));
GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1));
GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0));

cond.emplace_back(merge2);
merge = merge2;
}

static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) {
@@ -197,12 +235,27 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) {
TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) {
auto graph = std::make_shared<ComputeGraph>("test_graph");
NodePtr merge;
CreateLoopGraph(graph, merge);
AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);
vector<NodePtr> loop;
vector<NodePtr> cond;
CreateLoopGraph(graph, merge, loop, cond);

MarkForceUnknownForCondPass mark_force_unknown_pass;
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond
setenv("DUMP_GE_GRAPH", "1", true);
GE_DUMP(graph, "control_group");
unsetenv("DUMP_GE_GRAPH");

EXPECT_EQ(loop.size(), 2);
for (const auto &node : loop) {
EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP));
}

EXPECT_EQ(cond.size(), 3);
for (const auto &node : cond) {
int64_t group_index = -1;
EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index));
EXPECT_EQ(group_index, merge->GetOpDesc()->GetId());
}
}

TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) {


+ 14
- 14
tests/ut/ge/graph/passes/merge_pass_unittest.cc View File

@@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) {
}

/// Merge
/// | \
/// | \
/// | \.
/// | \.
/// Op1 Op2 Merge2
/// \ | |
/// \ | Op3
@@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da
}

/// Merge
/// | \
/// | \
/// | \.
/// | \.
/// Op1 Op2 Merge2
/// \ | | \
/// \ | | \.
/// \ | Op3
/// \ | :
/// NetOutput
@@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co

TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) {
/// Merge
/// | \
/// | \
/// | \.
/// | \.
/// Op1 Op2 Merge2
/// \ | |
/// \ | Op3
@@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) {
/// Op1 Op2 Merge2
/// \ |
/// \ Op3
/// \
/// \.
/// Merge3

ret = pass_.Run(merge_node2);
@@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) {
/// Op1
/// |
/// Merge
/// / \
/// / \.
/// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto node1 = NewNode("Op1", RELU, 1, 1);
@@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) {
/// Const
/// |
/// Merge Pass Const
/// / \ ===> / \
/// / \ ===> / \.
/// Op1 Op2 Op1 Op2
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1);
@@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes)
/// / | ===> / \(control anchor)
/// Op1 | \ Op1 Constant
/// Op2 Op3 |
/// / \
/// / \.
/// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1);
@@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1)
/// / | ===> / \(control anchor)
/// Op1 | \ Op1 Constant
/// Op2 Op3 |
/// / \
/// / \.
/// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1);
@@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) {
/// C
/// |
/// Merge
/// / \
/// / \.
/// Op1 Op2
auto switch_node = NewNode("Switch", SWITCH, 1, 2);
auto identity_node = NewNode("Identity", SWITCH, 1, 1);
@@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) {
/// .
/// .
/// C
/// / \
/// / \.
/// Op1 Op2
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);


+ 6
- 6
tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc View File

@@ -66,11 +66,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {

void BuildDefaultGraph() {
/// input
/// \
/// \.
/// sqrt pred
/// \ /
/// cast
/// / \
/// / \.
/// switch_t switch_f
/// | |
/// F T
@@ -118,13 +118,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {

void BuildDefaultGraph1() {
/// input
/// \
/// \.
/// sqrt pred
/// \ /
/// Switch
/// | |
/// ----F T----
/// \ | / \
/// \ | / \.
/// \ Merge1 Merge2
/// \_________|
input_node_ = NewNode("input", RELU, 0, 1);
@@ -164,14 +164,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {

void BuildDefaultGraph2() {
/// input input1
/// \ \
/// \ \.
/// sqrt pred sqrt1 pred1
/// \ / \ /
/// Switch Switch1
/// | | _______|
/// | | /
/// ____F T____
/// \ | / \
/// \ | / \.
/// \ Merge1 Merge2
/// \__________|
input_node_ = NewNode("input", RELU, 0, 2);


+ 3
- 3
tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc View File

@@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test {

namespace {
/// netoutput1
/// | \
///transdata1 \
/// | \
/// | \.
///transdata1 \.
/// | \.
/// | transdata2
/// | /
/// var1 const1


+ 8
- 8
tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc View File

@@ -35,7 +35,7 @@ namespace {
/// transdata1
/// |
/// reshape1
/// | \
/// | \.
/// var1 const1
ut::GraphBuilder Graph1Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
@@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() {
}

/// netoutput1
/// | \
///transdata1 \
/// | \
/// | \.
///transdata1 \.
/// | \.
/// reshape1 reshape2
/// | \ / \
/// | \ / \.
/// var1 const1 var2
ut::GraphBuilder Graph2Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g2");
@@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() {
}

/// netoutput1
/// | \
///transdata1 \
/// | \
/// | \.
///transdata1 \.
/// | \.
/// reshape1 transdata2
/// | \ /
/// var1 const1


+ 1
- 1
tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc View File

@@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test {

namespace {
/// netoutput1
/// | \
/// | \.
/// StackPush StackPop
/// | |
/// var1 const1


+ 6
- 6
tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc View File

@@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() {
/// netoutput1
/// |
/// merge1
/// / \
/// / \.
/// / add1
/// / F| \
/// / F| \.
/// addn1 swtich2 var3
/// \F T/ |
/// switch1 |
@@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() {
/// add1
/// / \T
/// var3 swtich2
/// T/ \
/// switch1 \
/// / \ \
/// T/ \.
/// switch1 \.
/// / \ \.
/// var1 var2 var4
ComputeGraphPtr BuildGraph3() {
auto builder = ut::GraphBuilder("g3");
@@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() {
/// netoutput1
/// |
/// merge1
/// / \
/// / \.
/// add1 addn1
/// / \T F/
/// var3 swtich2


+ 2
- 2
tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc View File

@@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) {
}

/// ----> netoutput1
/// / | \
/// / | \.
/// transdata1 transdata2 transdata3
/// \ / |
/// var1--------------
@@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() {
}

/// ---------> netoutput1
/// / | \
/// / | \.
/// transdata1 transdata2(l1) transdata3(l1)
/// \ / |
/// var1------------------


+ 7
- 7
tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc View File

@@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge)
/// -->transpose1 -->transpose3-->sinh2
/// | \ /
/// | -->transpose2
/// | \
/// | \.
/// / -->cast3-->cast4-->sinh3
/// /
/// / -->transpose4-->transpose5-->sinh4
/// / /
/// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5
/// \ \
/// \ \.
/// \ -->sinh6
/// \
/// \.
/// \ -->transpose6-->transpose7-->sinh9
/// \ /
/// -->reshape-->cast6-->cast7-->sinh8
/// \
/// \.
/// -->sinh7

/// after optimized graph
@@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge)
/// / /-->transpose3-->sinh2
/// -->Cast1
/// / \-->sinh7
/// / \
/// / \.
/// / -->sinh9
/// Node4D
/// \ -->sinh4
/// \ /
/// -->Cast5-->sinh5
/// \ \
/// \ \.
/// \ -->sinh6
/// \
/// \.
/// -->Cast7-->sinh8
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");



+ 2
- 2
tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc View File

@@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran
/// TransData TransData ... MatMul ...
/// \ | / / /
/// HcomAllReduce
/// / | \ \ \
/// / | \ \ \.
/// TransData TransData ... RealDiv ...
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
NodePtr allreduce =
@@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans
/// TransData TransData ... MatMul ...
/// \ | / / /
/// HcomAllReduce
/// / | \ \ \
/// / | \ \ \.
/// TransData TransData ... RealDiv ...
size_t symmetric_transdata_num = 20;
size_t asymmetric_transdata_num = 20;


+ 1
- 1
tests/ut/ge/graph/passes/variable_op_pass_unittest.cc View File

@@ -66,7 +66,7 @@ namespace {
/// transdata2
/// |
/// assign1
/// / \
/// / \.
/// transdata1 |
/// | |
/// var1 const1


+ 5
- 5
tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc View File

@@ -35,8 +35,8 @@ namespace {
/// shapeNo1
/// |
/// addnYes1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2

ComputeGraphPtr BuildGraph1() {
@@ -57,9 +57,9 @@ ComputeGraphPtr BuildGraph1() {

///
/// netoutput1
/// / \ \
/// add1 assign1 \
/// / \ / \ \
/// / \ \.
/// add1 assign1 \.
/// / \ / \ \.
/// var1 var2 const1 var3

ComputeGraphPtr BuildGraph2() {


Loading…
Cancel
Save