@@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName) | |||
INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | |||
{ | |||
const char *env = getenv(name); | |||
if (env != nullptr) { | |||
strcpy(value, env); | |||
} | |||
return 0; | |||
} | |||
@@ -720,7 +720,6 @@ set(PASS_TEST_FILES | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/passes/hccl_continuous_pass_unittest.cc" | |||
"graph/passes/hccl_memcpy_pass_unittest.cc" | |||
) | |||
set(KERNEL_TEST_FILES | |||
@@ -850,7 +849,6 @@ set(HYBRID_TEST_FILES | |||
"hybrid/executor/hybrid_model_async_executor_unittest.cc" | |||
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | |||
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | |||
) | |||
set(OTHERS_TEST_FILES | |||
@@ -877,6 +875,7 @@ add_library(ge_ut_graph STATIC | |||
target_compile_definitions(ge_ut_graph PRIVATE | |||
google=ascend_private | |||
FMK_SUPPORT_DUMP | |||
) | |||
target_compile_options(ge_ut_graph PRIVATE | |||
@@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||
/// B --> C(AllReduce) --- D | |||
/// / | |||
/// stream id: 0 A | |||
/// \ | |||
/// \. | |||
/// E --> F(AllReduce) --- G | |||
/// stream id: 2 2 2 | |||
/// | |||
@@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) { | |||
/// case of multi-output, then unuse stream | |||
/// sub1 | |||
/// / | \ | |||
/// / | \. | |||
/// sub2 sub3 sub4 | |||
TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | |||
SubGraphInfoPtr data = CreateDataSubgraph(); | |||
@@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | |||
/// if paralle id 1, then use stream | |||
/// sub1 | |||
/// / | | \ | |||
/// / | | \. | |||
/// sub2 sub3 sub4 sub5 | |||
TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | |||
SubGraphInfoPtr data = CreateDataSubgraph(); | |||
@@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | |||
/// if the param of engine independent is true, then set independent stream | |||
/// sub1 | |||
/// / | | \ | |||
/// / | | \. | |||
/// sub2 sub3 sub4 sub5 | |||
TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||
SubGraphInfoPtr data = CreateDataSubgraph(); | |||
@@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||
/// set stream based on stream label, and then based on independent | |||
/// sub1 | |||
/// / | | \ | |||
/// / | | \. | |||
/// sub2 sub3 sub4 sub5 | |||
TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | |||
SubGraphInfoPtr data = CreateDataSubgraph(); | |||
@@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test { | |||
/// | |||
/// A | |||
/// / \ | |||
/// / \. | |||
/// B C | |||
/// | | | |||
/// D 400 | |||
@@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test { | |||
}; | |||
/// D E | |||
/// | \ | \ | |||
/// | \ | \. | |||
/// F C G | |||
/// : | : | |||
/// H A I | |||
@@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) { | |||
EXPECT_EQ(graph->FindNode("D"), nullptr); | |||
} | |||
/// E F | |||
/// | \ | \ | |||
/// E F | |||
/// | \ | \. | |||
/// H C -> D G | |||
/// \ | : | |||
/// A I | |||
@@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test { | |||
/// reshape1 | |||
/// | | |||
/// add1 | |||
/// / \ | |||
/// / \. | |||
/// | | | |||
/// data1 const1 | |||
ComputeGraphPtr BuildGraph1() { | |||
@@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() { | |||
} | |||
/// sum1 | |||
/// / \ | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
/// / \. | |||
/// reshape1 addn1 | |||
/// | c | | |||
/// add1 <--- shape1 | |||
@@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::str | |||
/// Op1 | |||
/// | | |||
/// Merge | |||
/// / \ | |||
/// / \. | |||
/// Op2 Op3 | |||
TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | |||
auto builder = ut::GraphBuilder("g1"); | |||
@@ -245,7 +245,7 @@ TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | |||
/// Op1 | |||
/// | | |||
/// Merge | |||
/// / \ | |||
/// / \. | |||
/// Op2 Op3 | |||
TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | |||
auto builder = ut::GraphBuilder("g1"); | |||
@@ -459,7 +459,7 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||
/// data1 const | |||
/// \ / | |||
/// while | |||
/// / \ | |||
/// / \. | |||
/// | | | |||
/// cast1 cast2 | |||
ComputeGraphPtr BuildWhileGraph1() { | |||
@@ -34,11 +34,11 @@ namespace { | |||
/// net_output | |||
/// | | |||
/// merge | |||
/// / \ | |||
/// / \. | |||
/// square add | |||
/// F| T/ T\ | |||
/// F| T/ T\. | |||
/// switch1 switch2 | |||
/// / \ / \ | |||
/// / \ / \. | |||
/// var1 var2 var3 | |||
/// | |||
ComputeGraphPtr BuildGraph1() { | |||
@@ -173,8 +173,8 @@ namespace { | |||
/// shapeNo1 | |||
/// | | |||
/// addnYes1 | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
/// const1 const2 | |||
ComputeGraphPtr BuildGraph1() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() { | |||
/// shapeNo1 | |||
/// | c | |||
/// addnYes1 <----- dataNo1 | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
/// const1 const2 | |||
ComputeGraphPtr BuildGraph3() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() { | |||
/// shapeNo1 | |||
/// | c | |||
/// addnYes1 <--------- | |||
/// / \ \ | |||
/// / \ c \ | |||
/// / \ \. | |||
/// / \ c \. | |||
/// const1 const2 <----- dataNo1 | |||
ComputeGraphPtr BuildGraph4() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() { | |||
/// shapeNo1 | |||
/// | c | |||
/// addnYes1 <----- dataNo1 | |||
/// / \ | |||
/// / \. | |||
/// / \ c | |||
/// const1 const2 <----- dataNo2 | |||
ComputeGraphPtr BuildGraph5() { | |||
@@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() { | |||
/// addYes1 <---- const3 | |||
/// | | |||
/// addnYes1 <- | |||
/// / \ \ | |||
/// / \ \ | |||
/// / \ \. | |||
/// / \ \. | |||
/// const1 const2 const4 | |||
ComputeGraphPtr BuildGraph6() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() { | |||
} | |||
/// netoutput1 | |||
/// / \ | |||
/// / \. | |||
/// shapeNo1 ShpaeNo2 | |||
/// \ / | |||
/// huberLoss1 | |||
/// / | \ | |||
/// / | \ | |||
/// / | \. | |||
/// / | \. | |||
/// const1 const2 const3 | |||
ComputeGraphPtr BuildGraph7() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() { | |||
/// shapeNo1 | |||
/// | | |||
/// addnNo1 | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
/// const1 const2 | |||
ComputeGraphPtr BuildGraph8() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() { | |||
/// shapeNo1 | |||
/// | | |||
/// addnYes1 | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
/// const1 data1 | |||
ComputeGraphPtr BuildGraph9() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() { | |||
} | |||
/// netoutput1 | |||
/// / \ | |||
/// / \. | |||
/// addDim sqrt1 | |||
/// \ / | |||
/// switch1 | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
/// const1 const2 | |||
ComputeGraphPtr BuildGraph10() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -63,8 +63,8 @@ namespace { | |||
/// shapeNo1 | |||
/// | | |||
/// addnNo1 | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
/// const1 const2 | |||
ComputeGraphPtr BuildGraph8() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() { | |||
/// shapeNo1 | |||
/// | | |||
/// addnYes1 | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
///const1 data1 | |||
ComputeGraphPtr BuildGraph9() { | |||
auto builder = ut::GraphBuilder("test"); | |||
@@ -46,7 +46,7 @@ class UtestGraphPassesFoldingKernelSsdPriorboxKernel : public testing::Test { | |||
/// convolution data | |||
/// | / | |||
/// ssdpriorbox | |||
/// \ | |||
/// \. | |||
/// reshape | |||
class NodeBuilder { | |||
public: | |||
@@ -120,7 +120,7 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { | |||
/// graph with subgraph | |||
/// const | |||
/// / \ | |||
/// / \. | |||
/// cast1 cast1 | |||
/// \ / | |||
/// case | |||
@@ -69,62 +69,100 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string | |||
return graph.AddNode(op_desc); | |||
} | |||
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge, vector<NodePtr> &loop, vector<NodePtr> &cond) { | |||
/******************************************************************************* | |||
* Exit Identify | |||
* \ / \. | |||
* \ / \. | |||
* Switch Add | |||
* / | | | |||
* / | | | |||
* / | | | |||
* LoopCond | | | |||
* \ | | | |||
* \ | | | |||
* \ | | | |||
* Less | | | |||
* \ | NextIteration | |||
* \ | | | |||
* \ | | | |||
* Merge <---------| | |||
* | | |||
* | | |||
* Enter | |||
* | | |||
* +--------------------- Merge ----------------------+ | |||
* / | | |||
* / | | |||
* / | | |||
* / | | |||
* Exit Identify | | |||
* \ / \. | | |||
* \ / \. | | |||
* Switch Add Add | |||
* / | | | | |||
* / | | | | |||
* / | | | | |||
* LoopCond | | | | |||
* \ | | | | |||
* \ | | | | |||
* \ | | | | |||
* Less | | | | |||
* \ | NextIteration | | |||
* \ | | | | |||
* \ | | | | |||
* Merge <---------| | | |||
* | | | |||
* | | | |||
* Enter | | |||
* \ | | |||
* \ | | |||
* Switch Switch | |||
* | | | |||
* +-----------------Equal----------------------+ | |||
* | | |||
******************************************************************************/ | |||
auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||
auto data1 = CreateNode(*graph, "data1", DATA, 1, 1); | |||
auto data2 = CreateNode(*graph, "data2", DATA, 1, 1); | |||
auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1); | |||
auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2); | |||
auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2); | |||
auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | |||
auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); | |||
auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||
auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2); | |||
auto less1 = CreateNode(*graph, "less1", LESS, 2, 1); | |||
auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | |||
auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); | |||
auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2); | |||
auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | |||
auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||
auto add1 = CreateNode(*graph, "add1", ADD, 2, 1); | |||
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | |||
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | |||
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||
auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1); | |||
auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1); | |||
auto add2 = CreateNode(*graph, "add2", ADD, 2, 1); | |||
auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2); | |||
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); | |||
cond.emplace_back(switch1); | |||
cond.emplace_back(switch2); | |||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false | |||
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); | |||
loop.emplace_back(merge1); | |||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false | |||
GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true | |||
loop.emplace_back(switch3); | |||
GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||
merge = merge1; | |||
GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true | |||
GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||
cond.emplace_back(merge2); | |||
merge = merge2; | |||
} | |||
static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||
@@ -197,12 +235,27 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||
TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | |||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||
NodePtr merge; | |||
CreateLoopGraph(graph, merge); | |||
AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||
vector<NodePtr> loop; | |||
vector<NodePtr> cond; | |||
CreateLoopGraph(graph, merge, loop, cond); | |||
MarkForceUnknownForCondPass mark_force_unknown_pass; | |||
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | |||
setenv("DUMP_GE_GRAPH", "1", true); | |||
GE_DUMP(graph, "control_group"); | |||
unsetenv("DUMP_GE_GRAPH"); | |||
EXPECT_EQ(loop.size(), 2); | |||
for (const auto &node : loop) { | |||
EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)); | |||
} | |||
EXPECT_EQ(cond.size(), 3); | |||
for (const auto &node : cond) { | |||
int64_t group_index = -1; | |||
EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)); | |||
EXPECT_EQ(group_index, merge->GetOpDesc()->GetId()); | |||
} | |||
} | |||
TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | |||
@@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) { | |||
} | |||
/// Merge | |||
/// | \ | |||
/// | \ | |||
/// | \. | |||
/// | \. | |||
/// Op1 Op2 Merge2 | |||
/// \ | | | |||
/// \ | Op3 | |||
@@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da | |||
} | |||
/// Merge | |||
/// | \ | |||
/// | \ | |||
/// | \. | |||
/// | \. | |||
/// Op1 Op2 Merge2 | |||
/// \ | | \ | |||
/// \ | | \. | |||
/// \ | Op3 | |||
/// \ | : | |||
/// NetOutput | |||
@@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co | |||
TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | |||
/// Merge | |||
/// | \ | |||
/// | \ | |||
/// | \. | |||
/// | \. | |||
/// Op1 Op2 Merge2 | |||
/// \ | | | |||
/// \ | Op3 | |||
@@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | |||
/// Op1 Op2 Merge2 | |||
/// \ | | |||
/// \ Op3 | |||
/// \ | |||
/// \. | |||
/// Merge3 | |||
ret = pass_.Run(merge_node2); | |||
@@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) { | |||
/// Op1 | |||
/// | | |||
/// Merge | |||
/// / \ | |||
/// / \. | |||
/// Op2 Op3 | |||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | |||
auto node1 = NewNode("Op1", RELU, 1, 1); | |||
@@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) { | |||
/// Const | |||
/// | | |||
/// Merge Pass Const | |||
/// / \ ===> / \ | |||
/// / \ ===> / \. | |||
/// Op1 Op2 Op1 Op2 | |||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | |||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | |||
@@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes) | |||
/// / | ===> / \(control anchor) | |||
/// Op1 | \ Op1 Constant | |||
/// Op2 Op3 | | |||
/// / \ | |||
/// / \. | |||
/// Op2 Op3 | |||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | |||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | |||
@@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1) | |||
/// / | ===> / \(control anchor) | |||
/// Op1 | \ Op1 Constant | |||
/// Op2 Op3 | | |||
/// / \ | |||
/// / \. | |||
/// Op2 Op3 | |||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | |||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | |||
@@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||
/// C | |||
/// | | |||
/// Merge | |||
/// / \ | |||
/// / \. | |||
/// Op1 Op2 | |||
auto switch_node = NewNode("Switch", SWITCH, 1, 2); | |||
auto identity_node = NewNode("Identity", SWITCH, 1, 1); | |||
@@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||
/// . | |||
/// . | |||
/// C | |||
/// / \ | |||
/// / \. | |||
/// Op1 Op2 | |||
auto ret = pass_.Run(merge_node); | |||
EXPECT_EQ(ret, SUCCESS); | |||
@@ -66,11 +66,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||
void BuildDefaultGraph() { | |||
/// input | |||
/// \ | |||
/// \. | |||
/// sqrt pred | |||
/// \ / | |||
/// cast | |||
/// / \ | |||
/// / \. | |||
/// switch_t switch_f | |||
/// | | | |||
/// F T | |||
@@ -118,13 +118,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||
void BuildDefaultGraph1() { | |||
/// input | |||
/// \ | |||
/// \. | |||
/// sqrt pred | |||
/// \ / | |||
/// Switch | |||
/// | | | |||
/// ----F T---- | |||
/// \ | / \ | |||
/// \ | / \. | |||
/// \ Merge1 Merge2 | |||
/// \_________| | |||
input_node_ = NewNode("input", RELU, 0, 1); | |||
@@ -164,14 +164,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||
void BuildDefaultGraph2() { | |||
/// input input1 | |||
/// \ \ | |||
/// \ \. | |||
/// sqrt pred sqrt1 pred1 | |||
/// \ / \ / | |||
/// Switch Switch1 | |||
/// | | _______| | |||
/// | | / | |||
/// ____F T____ | |||
/// \ | / \ | |||
/// \ | / \. | |||
/// \ Merge1 Merge2 | |||
/// \__________| | |||
input_node_ = NewNode("input", RELU, 0, 2); | |||
@@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test { | |||
namespace { | |||
/// netoutput1 | |||
/// | \ | |||
///transdata1 \ | |||
/// | \ | |||
/// | \. | |||
///transdata1 \. | |||
/// | \. | |||
/// | transdata2 | |||
/// | / | |||
/// var1 const1 | |||
@@ -35,7 +35,7 @@ namespace { | |||
/// transdata1 | |||
/// | | |||
/// reshape1 | |||
/// | \ | |||
/// | \. | |||
/// var1 const1 | |||
ut::GraphBuilder Graph1Builder() { | |||
ut::GraphBuilder builder = ut::GraphBuilder("g1"); | |||
@@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() { | |||
} | |||
/// netoutput1 | |||
/// | \ | |||
///transdata1 \ | |||
/// | \ | |||
/// | \. | |||
///transdata1 \. | |||
/// | \. | |||
/// reshape1 reshape2 | |||
/// | \ / \ | |||
/// | \ / \. | |||
/// var1 const1 var2 | |||
ut::GraphBuilder Graph2Builder() { | |||
ut::GraphBuilder builder = ut::GraphBuilder("g2"); | |||
@@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() { | |||
} | |||
/// netoutput1 | |||
/// | \ | |||
///transdata1 \ | |||
/// | \ | |||
/// | \. | |||
///transdata1 \. | |||
/// | \. | |||
/// reshape1 transdata2 | |||
/// | \ / | |||
/// var1 const1 | |||
@@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test { | |||
namespace { | |||
/// netoutput1 | |||
/// | \ | |||
/// | \. | |||
/// StackPush StackPop | |||
/// | | | |||
/// var1 const1 | |||
@@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() { | |||
/// netoutput1 | |||
/// | | |||
/// merge1 | |||
/// / \ | |||
/// / \. | |||
/// / add1 | |||
/// / F| \ | |||
/// / F| \. | |||
/// addn1 swtich2 var3 | |||
/// \F T/ | | |||
/// switch1 | | |||
@@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() { | |||
/// add1 | |||
/// / \T | |||
/// var3 swtich2 | |||
/// T/ \ | |||
/// switch1 \ | |||
/// / \ \ | |||
/// T/ \. | |||
/// switch1 \. | |||
/// / \ \. | |||
/// var1 var2 var4 | |||
ComputeGraphPtr BuildGraph3() { | |||
auto builder = ut::GraphBuilder("g3"); | |||
@@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() { | |||
/// netoutput1 | |||
/// | | |||
/// merge1 | |||
/// / \ | |||
/// / \. | |||
/// add1 addn1 | |||
/// / \T F/ | |||
/// var3 swtich2 | |||
@@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) { | |||
} | |||
/// ----> netoutput1 | |||
/// / | \ | |||
/// / | \. | |||
/// transdata1 transdata2 transdata3 | |||
/// \ / | | |||
/// var1-------------- | |||
@@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() { | |||
} | |||
/// ---------> netoutput1 | |||
/// / | \ | |||
/// / | \. | |||
/// transdata1 transdata2(l1) transdata3(l1) | |||
/// \ / | | |||
/// var1------------------ | |||
@@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||
/// -->transpose1 -->transpose3-->sinh2 | |||
/// | \ / | |||
/// | -->transpose2 | |||
/// | \ | |||
/// | \. | |||
/// / -->cast3-->cast4-->sinh3 | |||
/// / | |||
/// / -->transpose4-->transpose5-->sinh4 | |||
/// / / | |||
/// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 | |||
/// \ \ | |||
/// \ \. | |||
/// \ -->sinh6 | |||
/// \ | |||
/// \. | |||
/// \ -->transpose6-->transpose7-->sinh9 | |||
/// \ / | |||
/// -->reshape-->cast6-->cast7-->sinh8 | |||
/// \ | |||
/// \. | |||
/// -->sinh7 | |||
/// after optimized graph | |||
@@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||
/// / /-->transpose3-->sinh2 | |||
/// -->Cast1 | |||
/// / \-->sinh7 | |||
/// / \ | |||
/// / \. | |||
/// / -->sinh9 | |||
/// Node4D | |||
/// \ -->sinh4 | |||
/// \ / | |||
/// -->Cast5-->sinh5 | |||
/// \ \ | |||
/// \ \. | |||
/// \ -->sinh6 | |||
/// \ | |||
/// \. | |||
/// -->Cast7-->sinh8 | |||
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
@@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran | |||
/// TransData TransData ... MatMul ... | |||
/// \ | / / / | |||
/// HcomAllReduce | |||
/// / | \ \ \ | |||
/// / | \ \ \. | |||
/// TransData TransData ... RealDiv ... | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
NodePtr allreduce = | |||
@@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans | |||
/// TransData TransData ... MatMul ... | |||
/// \ | / / / | |||
/// HcomAllReduce | |||
/// / | \ \ \ | |||
/// / | \ \ \. | |||
/// TransData TransData ... RealDiv ... | |||
size_t symmetric_transdata_num = 20; | |||
size_t asymmetric_transdata_num = 20; | |||
@@ -66,7 +66,7 @@ namespace { | |||
/// transdata2 | |||
/// | | |||
/// assign1 | |||
/// / \ | |||
/// / \. | |||
/// transdata1 | | |||
/// | | | |||
/// var1 const1 | |||
@@ -35,8 +35,8 @@ namespace { | |||
/// shapeNo1 | |||
/// | | |||
/// addnYes1 | |||
/// / \ | |||
/// / \ | |||
/// / \. | |||
/// / \. | |||
/// const1 const2 | |||
ComputeGraphPtr BuildGraph1() { | |||
@@ -57,9 +57,9 @@ ComputeGraphPtr BuildGraph1() { | |||
/// | |||
/// netoutput1 | |||
/// / \ \ | |||
/// add1 assign1 \ | |||
/// / \ / \ \ | |||
/// / \ \. | |||
/// add1 assign1 \. | |||
/// / \ / \ \. | |||
/// var1 var2 const1 var3 | |||
ComputeGraphPtr BuildGraph2() { | |||