| @@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName) | |||||
| INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | ||||
| { | { | ||||
| const char *env = getenv(name); | |||||
| if (env != nullptr) { | |||||
| strcpy(value, env); | |||||
| } | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -726,7 +726,6 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/memcpy_addr_async_unittest.cc" | "graph/passes/memcpy_addr_async_unittest.cc" | ||||
| "graph/passes/hccl_continuous_pass_unittest.cc" | "graph/passes/hccl_continuous_pass_unittest.cc" | ||||
| "graph/passes/hccl_memcpy_pass_unittest.cc" | "graph/passes/hccl_memcpy_pass_unittest.cc" | ||||
| ) | ) | ||||
| set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
| @@ -859,7 +858,6 @@ set(HYBRID_TEST_FILES | |||||
| "hybrid/executor/hybrid_model_async_executor_unittest.cc" | "hybrid/executor/hybrid_model_async_executor_unittest.cc" | ||||
| "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | ||||
| "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | ||||
| ) | ) | ||||
| set(OTHERS_TEST_FILES | set(OTHERS_TEST_FILES | ||||
| @@ -887,6 +885,7 @@ add_library(ge_ut_graph STATIC | |||||
| target_compile_definitions(ge_ut_graph PRIVATE | target_compile_definitions(ge_ut_graph PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| FMK_SUPPORT_DUMP | |||||
| ) | ) | ||||
| target_compile_options(ge_ut_graph PRIVATE | target_compile_options(ge_ut_graph PRIVATE | ||||
| @@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
| /// B --> C(AllReduce) --- D | /// B --> C(AllReduce) --- D | ||||
| /// / | /// / | ||||
| /// stream id: 0 A | /// stream id: 0 A | ||||
| /// \ | |||||
| /// \. | |||||
| /// E --> F(AllReduce) --- G | /// E --> F(AllReduce) --- G | ||||
| /// stream id: 2 2 2 | /// stream id: 2 2 2 | ||||
| /// | /// | ||||
| @@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) { | |||||
| /// case of multi-output, then unuse stream | /// case of multi-output, then unuse stream | ||||
| /// sub1 | /// sub1 | ||||
| /// / | \ | |||||
| /// / | \. | |||||
| /// sub2 sub3 sub4 | /// sub2 sub3 sub4 | ||||
| TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | ||||
| SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
| @@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | |||||
| /// if paralle id 1, then use stream | /// if paralle id 1, then use stream | ||||
| /// sub1 | /// sub1 | ||||
| /// / | | \ | |||||
| /// / | | \. | |||||
| /// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
| TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | ||||
| SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
| @@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | |||||
| /// if the param of engine independent is true, then set independent stream | /// if the param of engine independent is true, then set independent stream | ||||
| /// sub1 | /// sub1 | ||||
| /// / | | \ | |||||
| /// / | | \. | |||||
| /// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
| TEST_F(UtestLogicalStreamAllocator, test_independent) { | TEST_F(UtestLogicalStreamAllocator, test_independent) { | ||||
| SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
| @@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||||
| /// set stream based on stream label, and then based on independent | /// set stream based on stream label, and then based on independent | ||||
| /// sub1 | /// sub1 | ||||
| /// / | | \ | |||||
| /// / | | \. | |||||
| /// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
| TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | ||||
| SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
| @@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test { | |||||
| /// | /// | ||||
| /// A | /// A | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// B C | /// B C | ||||
| /// | | | /// | | | ||||
| /// D 400 | /// D 400 | ||||
| @@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test { | |||||
| }; | }; | ||||
| /// D E | /// D E | ||||
| /// | \ | \ | |||||
| /// | \ | \. | |||||
| /// F C G | /// F C G | ||||
| /// : | : | /// : | : | ||||
| /// H A I | /// H A I | ||||
| @@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) { | |||||
| EXPECT_EQ(graph->FindNode("D"), nullptr); | EXPECT_EQ(graph->FindNode("D"), nullptr); | ||||
| } | } | ||||
| /// E F | |||||
| /// | \ | \ | |||||
| /// E F | |||||
| /// | \ | \. | |||||
| /// H C -> D G | /// H C -> D G | ||||
| /// \ | : | /// \ | : | ||||
| /// A I | /// A I | ||||
| @@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test { | |||||
| /// reshape1 | /// reshape1 | ||||
| /// | | /// | | ||||
| /// add1 | /// add1 | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// | | | /// | | | ||||
| /// data1 const1 | /// data1 const1 | ||||
| ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
| @@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
| } | } | ||||
| /// sum1 | /// sum1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// reshape1 addn1 | /// reshape1 addn1 | ||||
| /// | c | | /// | c | | ||||
| /// add1 <--- shape1 | /// add1 <--- shape1 | ||||
| @@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::str | |||||
| /// Op1 | /// Op1 | ||||
| /// | | /// | | ||||
| /// Merge | /// Merge | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// Op2 Op3 | /// Op2 Op3 | ||||
| TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | ||||
| auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
| @@ -245,7 +245,7 @@ TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | |||||
| /// Op1 | /// Op1 | ||||
| /// | | /// | | ||||
| /// Merge | /// Merge | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// Op2 Op3 | /// Op2 Op3 | ||||
| TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | ||||
| auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
| @@ -459,7 +459,7 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||||
| /// data1 const | /// data1 const | ||||
| /// \ / | /// \ / | ||||
| /// while | /// while | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// | | | /// | | | ||||
| /// cast1 cast2 | /// cast1 cast2 | ||||
| ComputeGraphPtr BuildWhileGraph1() { | ComputeGraphPtr BuildWhileGraph1() { | ||||
| @@ -34,11 +34,11 @@ namespace { | |||||
| /// net_output | /// net_output | ||||
| /// | | /// | | ||||
| /// merge | /// merge | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// square add | /// square add | ||||
| /// F| T/ T\ | |||||
| /// F| T/ T\. | |||||
| /// switch1 switch2 | /// switch1 switch2 | ||||
| /// / \ / \ | |||||
| /// / \ / \. | |||||
| /// var1 var2 var3 | /// var1 var2 var3 | ||||
| /// | /// | ||||
| ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
| @@ -173,8 +173,8 @@ namespace { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | | /// | | ||||
| /// addnYes1 | /// addnYes1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// const1 const2 | /// const1 const2 | ||||
| ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | c | /// | c | ||||
| /// addnYes1 <----- dataNo1 | /// addnYes1 <----- dataNo1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// const1 const2 | /// const1 const2 | ||||
| ComputeGraphPtr BuildGraph3() { | ComputeGraphPtr BuildGraph3() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | c | /// | c | ||||
| /// addnYes1 <--------- | /// addnYes1 <--------- | ||||
| /// / \ \ | |||||
| /// / \ c \ | |||||
| /// / \ \. | |||||
| /// / \ c \. | |||||
| /// const1 const2 <----- dataNo1 | /// const1 const2 <----- dataNo1 | ||||
| ComputeGraphPtr BuildGraph4() { | ComputeGraphPtr BuildGraph4() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | c | /// | c | ||||
| /// addnYes1 <----- dataNo1 | /// addnYes1 <----- dataNo1 | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \ c | /// / \ c | ||||
| /// const1 const2 <----- dataNo2 | /// const1 const2 <----- dataNo2 | ||||
| ComputeGraphPtr BuildGraph5() { | ComputeGraphPtr BuildGraph5() { | ||||
| @@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() { | |||||
| /// addYes1 <---- const3 | /// addYes1 <---- const3 | ||||
| /// | | /// | | ||||
| /// addnYes1 <- | /// addnYes1 <- | ||||
| /// / \ \ | |||||
| /// / \ \ | |||||
| /// / \ \. | |||||
| /// / \ \. | |||||
| /// const1 const2 const4 | /// const1 const2 const4 | ||||
| ComputeGraphPtr BuildGraph6() { | ComputeGraphPtr BuildGraph6() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() { | |||||
| } | } | ||||
| /// netoutput1 | /// netoutput1 | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// shapeNo1 ShpaeNo2 | /// shapeNo1 ShpaeNo2 | ||||
| /// \ / | /// \ / | ||||
| /// huberLoss1 | /// huberLoss1 | ||||
| /// / | \ | |||||
| /// / | \ | |||||
| /// / | \. | |||||
| /// / | \. | |||||
| /// const1 const2 const3 | /// const1 const2 const3 | ||||
| ComputeGraphPtr BuildGraph7() { | ComputeGraphPtr BuildGraph7() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | | /// | | ||||
| /// addnNo1 | /// addnNo1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// const1 const2 | /// const1 const2 | ||||
| ComputeGraphPtr BuildGraph8() { | ComputeGraphPtr BuildGraph8() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | | /// | | ||||
| /// addnYes1 | /// addnYes1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// const1 data1 | /// const1 data1 | ||||
| ComputeGraphPtr BuildGraph9() { | ComputeGraphPtr BuildGraph9() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() { | |||||
| } | } | ||||
| /// netoutput1 | /// netoutput1 | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// addDim sqrt1 | /// addDim sqrt1 | ||||
| /// \ / | /// \ / | ||||
| /// switch1 | /// switch1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// const1 const2 | /// const1 const2 | ||||
| ComputeGraphPtr BuildGraph10() { | ComputeGraphPtr BuildGraph10() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -63,8 +63,8 @@ namespace { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | | /// | | ||||
| /// addnNo1 | /// addnNo1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// const1 const2 | /// const1 const2 | ||||
| ComputeGraphPtr BuildGraph8() { | ComputeGraphPtr BuildGraph8() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | | /// | | ||||
| /// addnYes1 | /// addnYes1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| ///const1 data1 | ///const1 data1 | ||||
| ComputeGraphPtr BuildGraph9() { | ComputeGraphPtr BuildGraph9() { | ||||
| auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
| @@ -46,7 +46,7 @@ class UtestGraphPassesFoldingKernelSsdPriorboxKernel : public testing::Test { | |||||
| /// convolution data | /// convolution data | ||||
| /// | / | /// | / | ||||
| /// ssdpriorbox | /// ssdpriorbox | ||||
| /// \ | |||||
| /// \. | |||||
| /// reshape | /// reshape | ||||
| class NodeBuilder { | class NodeBuilder { | ||||
| public: | public: | ||||
| @@ -120,7 +120,7 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { | |||||
| /// graph with subgraph | /// graph with subgraph | ||||
| /// const | /// const | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// cast1 cast1 | /// cast1 cast1 | ||||
| /// \ / | /// \ / | ||||
| /// case | /// case | ||||
| @@ -69,62 +69,100 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string | |||||
| return graph.AddNode(op_desc); | return graph.AddNode(op_desc); | ||||
| } | } | ||||
| static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
| static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge, vector<NodePtr> &loop, vector<NodePtr> &cond) { | |||||
| /******************************************************************************* | /******************************************************************************* | ||||
| * Exit Identify | |||||
| * \ / \. | |||||
| * \ / \. | |||||
| * Switch Add | |||||
| * / | | | |||||
| * / | | | |||||
| * / | | | |||||
| * LoopCond | | | |||||
| * \ | | | |||||
| * \ | | | |||||
| * \ | | | |||||
| * Less | | | |||||
| * \ | NextIteration | |||||
| * \ | | | |||||
| * \ | | | |||||
| * Merge <---------| | |||||
| * | | |||||
| * | | |||||
| * Enter | |||||
| * | | |||||
| * +--------------------- Merge ----------------------+ | |||||
| * / | | |||||
| * / | | |||||
| * / | | |||||
| * / | | |||||
| * Exit Identify | | |||||
| * \ / \. | | |||||
| * \ / \. | | |||||
| * Switch Add Add | |||||
| * / | | | | |||||
| * / | | | | |||||
| * / | | | | |||||
| * LoopCond | | | | |||||
| * \ | | | | |||||
| * \ | | | | |||||
| * \ | | | | |||||
| * Less | | | | |||||
| * \ | NextIteration | | |||||
| * \ | | | | |||||
| * \ | | | | |||||
| * Merge <---------| | | |||||
| * | | | |||||
| * | | | |||||
| * Enter | | |||||
| * \ | | |||||
| * \ | | |||||
| * Switch Switch | |||||
| * | | | |||||
| * +-----------------Equal----------------------+ | |||||
| * | | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
| auto data1 = CreateNode(*graph, "data1", DATA, 1, 1); | |||||
| auto data2 = CreateNode(*graph, "data2", DATA, 1, 1); | |||||
| auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1); | |||||
| auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2); | |||||
| auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2); | |||||
| auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | ||||
| auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); | |||||
| auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||||
| auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2); | |||||
| auto less1 = CreateNode(*graph, "less1", LESS, 2, 1); | |||||
| auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | ||||
| auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); | |||||
| auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2); | |||||
| auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | ||||
| auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||||
| auto add1 = CreateNode(*graph, "add1", ADD, 2, 1); | |||||
| auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | ||||
| auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | ||||
| auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1); | |||||
| auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1); | |||||
| auto add2 = CreateNode(*graph, "add2", ADD, 2, 1); | |||||
| auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2); | |||||
| auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | ||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); | |||||
| cond.emplace_back(switch1); | |||||
| cond.emplace_back(switch2); | |||||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false | |||||
| GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | ||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | ||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | ||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | ||||
| GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); | |||||
| loop.emplace_back(merge1); | |||||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false | |||||
| GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true | |||||
| loop.emplace_back(switch3); | |||||
| GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | ||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | ||||
| GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | ||||
| GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | ||||
| GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
| merge = merge1; | |||||
| GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true | |||||
| GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
| cond.emplace_back(merge2); | |||||
| merge = merge2; | |||||
| } | } | ||||
| static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | ||||
| @@ -197,12 +235,27 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
| TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | ||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | auto graph = std::make_shared<ComputeGraph>("test_graph"); | ||||
| NodePtr merge; | NodePtr merge; | ||||
| CreateLoopGraph(graph, merge); | |||||
| AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||||
| vector<NodePtr> loop; | |||||
| vector<NodePtr> cond; | |||||
| CreateLoopGraph(graph, merge, loop, cond); | |||||
| MarkForceUnknownForCondPass mark_force_unknown_pass; | MarkForceUnknownForCondPass mark_force_unknown_pass; | ||||
| EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | ||||
| setenv("DUMP_GE_GRAPH", "1", true); | |||||
| GE_DUMP(graph, "control_group"); | |||||
| unsetenv("DUMP_GE_GRAPH"); | |||||
| EXPECT_EQ(loop.size(), 2); | |||||
| for (const auto &node : loop) { | |||||
| EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)); | |||||
| } | |||||
| EXPECT_EQ(cond.size(), 3); | |||||
| for (const auto &node : cond) { | |||||
| int64_t group_index = -1; | |||||
| EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)); | |||||
| EXPECT_EQ(group_index, merge->GetOpDesc()->GetId()); | |||||
| } | |||||
| } | } | ||||
| TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | ||||
| @@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) { | |||||
| } | } | ||||
| /// Merge | /// Merge | ||||
| /// | \ | |||||
| /// | \ | |||||
| /// | \. | |||||
| /// | \. | |||||
| /// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
| /// \ | | | /// \ | | | ||||
| /// \ | Op3 | /// \ | Op3 | ||||
| @@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da | |||||
| } | } | ||||
| /// Merge | /// Merge | ||||
| /// | \ | |||||
| /// | \ | |||||
| /// | \. | |||||
| /// | \. | |||||
| /// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
| /// \ | | \ | |||||
| /// \ | | \. | |||||
| /// \ | Op3 | /// \ | Op3 | ||||
| /// \ | : | /// \ | : | ||||
| /// NetOutput | /// NetOutput | ||||
| @@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co | |||||
| TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | ||||
| /// Merge | /// Merge | ||||
| /// | \ | |||||
| /// | \ | |||||
| /// | \. | |||||
| /// | \. | |||||
| /// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
| /// \ | | | /// \ | | | ||||
| /// \ | Op3 | /// \ | Op3 | ||||
| @@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | |||||
| /// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
| /// \ | | /// \ | | ||||
| /// \ Op3 | /// \ Op3 | ||||
| /// \ | |||||
| /// \. | |||||
| /// Merge3 | /// Merge3 | ||||
| ret = pass_.Run(merge_node2); | ret = pass_.Run(merge_node2); | ||||
| @@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) { | |||||
| /// Op1 | /// Op1 | ||||
| /// | | /// | | ||||
| /// Merge | /// Merge | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// Op2 Op3 | /// Op2 Op3 | ||||
| auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
| auto node1 = NewNode("Op1", RELU, 1, 1); | auto node1 = NewNode("Op1", RELU, 1, 1); | ||||
| @@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) { | |||||
| /// Const | /// Const | ||||
| /// | | /// | | ||||
| /// Merge Pass Const | /// Merge Pass Const | ||||
| /// / \ ===> / \ | |||||
| /// / \ ===> / \. | |||||
| /// Op1 Op2 Op1 Op2 | /// Op1 Op2 Op1 Op2 | ||||
| auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
| auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
| @@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes) | |||||
| /// / | ===> / \(control anchor) | /// / | ===> / \(control anchor) | ||||
| /// Op1 | \ Op1 Constant | /// Op1 | \ Op1 Constant | ||||
| /// Op2 Op3 | | /// Op2 Op3 | | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// Op2 Op3 | /// Op2 Op3 | ||||
| auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
| auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
| @@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1) | |||||
| /// / | ===> / \(control anchor) | /// / | ===> / \(control anchor) | ||||
| /// Op1 | \ Op1 Constant | /// Op1 | \ Op1 Constant | ||||
| /// Op2 Op3 | | /// Op2 Op3 | | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// Op2 Op3 | /// Op2 Op3 | ||||
| auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
| auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
| @@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||||
| /// C | /// C | ||||
| /// | | /// | | ||||
| /// Merge | /// Merge | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// Op1 Op2 | /// Op1 Op2 | ||||
| auto switch_node = NewNode("Switch", SWITCH, 1, 2); | auto switch_node = NewNode("Switch", SWITCH, 1, 2); | ||||
| auto identity_node = NewNode("Identity", SWITCH, 1, 1); | auto identity_node = NewNode("Identity", SWITCH, 1, 1); | ||||
| @@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||||
| /// . | /// . | ||||
| /// . | /// . | ||||
| /// C | /// C | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// Op1 Op2 | /// Op1 Op2 | ||||
| auto ret = pass_.Run(merge_node); | auto ret = pass_.Run(merge_node); | ||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| @@ -66,11 +66,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
| void BuildDefaultGraph() { | void BuildDefaultGraph() { | ||||
| /// input | /// input | ||||
| /// \ | |||||
| /// \. | |||||
| /// sqrt pred | /// sqrt pred | ||||
| /// \ / | /// \ / | ||||
| /// cast | /// cast | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// switch_t switch_f | /// switch_t switch_f | ||||
| /// | | | /// | | | ||||
| /// F T | /// F T | ||||
| @@ -118,13 +118,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
| void BuildDefaultGraph1() { | void BuildDefaultGraph1() { | ||||
| /// input | /// input | ||||
| /// \ | |||||
| /// \. | |||||
| /// sqrt pred | /// sqrt pred | ||||
| /// \ / | /// \ / | ||||
| /// Switch | /// Switch | ||||
| /// | | | /// | | | ||||
| /// ----F T---- | /// ----F T---- | ||||
| /// \ | / \ | |||||
| /// \ | / \. | |||||
| /// \ Merge1 Merge2 | /// \ Merge1 Merge2 | ||||
| /// \_________| | /// \_________| | ||||
| input_node_ = NewNode("input", RELU, 0, 1); | input_node_ = NewNode("input", RELU, 0, 1); | ||||
| @@ -164,14 +164,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
| void BuildDefaultGraph2() { | void BuildDefaultGraph2() { | ||||
| /// input input1 | /// input input1 | ||||
| /// \ \ | |||||
| /// \ \. | |||||
| /// sqrt pred sqrt1 pred1 | /// sqrt pred sqrt1 pred1 | ||||
| /// \ / \ / | /// \ / \ / | ||||
| /// Switch Switch1 | /// Switch Switch1 | ||||
| /// | | _______| | /// | | _______| | ||||
| /// | | / | /// | | / | ||||
| /// ____F T____ | /// ____F T____ | ||||
| /// \ | / \ | |||||
| /// \ | / \. | |||||
| /// \ Merge1 Merge2 | /// \ Merge1 Merge2 | ||||
| /// \__________| | /// \__________| | ||||
| input_node_ = NewNode("input", RELU, 0, 2); | input_node_ = NewNode("input", RELU, 0, 2); | ||||
| @@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test { | |||||
| namespace { | namespace { | ||||
| /// netoutput1 | /// netoutput1 | ||||
| /// | \ | |||||
| ///transdata1 \ | |||||
| /// | \ | |||||
| /// | \. | |||||
| ///transdata1 \. | |||||
| /// | \. | |||||
| /// | transdata2 | /// | transdata2 | ||||
| /// | / | /// | / | ||||
| /// var1 const1 | /// var1 const1 | ||||
| @@ -35,7 +35,7 @@ namespace { | |||||
| /// transdata1 | /// transdata1 | ||||
| /// | | /// | | ||||
| /// reshape1 | /// reshape1 | ||||
| /// | \ | |||||
| /// | \. | |||||
| /// var1 const1 | /// var1 const1 | ||||
| ut::GraphBuilder Graph1Builder() { | ut::GraphBuilder Graph1Builder() { | ||||
| ut::GraphBuilder builder = ut::GraphBuilder("g1"); | ut::GraphBuilder builder = ut::GraphBuilder("g1"); | ||||
| @@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() { | |||||
| } | } | ||||
| /// netoutput1 | /// netoutput1 | ||||
| /// | \ | |||||
| ///transdata1 \ | |||||
| /// | \ | |||||
| /// | \. | |||||
| ///transdata1 \. | |||||
| /// | \. | |||||
| /// reshape1 reshape2 | /// reshape1 reshape2 | ||||
| /// | \ / \ | |||||
| /// | \ / \. | |||||
| /// var1 const1 var2 | /// var1 const1 var2 | ||||
| ut::GraphBuilder Graph2Builder() { | ut::GraphBuilder Graph2Builder() { | ||||
| ut::GraphBuilder builder = ut::GraphBuilder("g2"); | ut::GraphBuilder builder = ut::GraphBuilder("g2"); | ||||
| @@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() { | |||||
| } | } | ||||
| /// netoutput1 | /// netoutput1 | ||||
| /// | \ | |||||
| ///transdata1 \ | |||||
| /// | \ | |||||
| /// | \. | |||||
| ///transdata1 \. | |||||
| /// | \. | |||||
| /// reshape1 transdata2 | /// reshape1 transdata2 | ||||
| /// | \ / | /// | \ / | ||||
| /// var1 const1 | /// var1 const1 | ||||
| @@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test { | |||||
| namespace { | namespace { | ||||
| /// netoutput1 | /// netoutput1 | ||||
| /// | \ | |||||
| /// | \. | |||||
| /// StackPush StackPop | /// StackPush StackPop | ||||
| /// | | | /// | | | ||||
| /// var1 const1 | /// var1 const1 | ||||
| @@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
| /// netoutput1 | /// netoutput1 | ||||
| /// | | /// | | ||||
| /// merge1 | /// merge1 | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// / add1 | /// / add1 | ||||
| /// / F| \ | |||||
| /// / F| \. | |||||
| /// addn1 swtich2 var3 | /// addn1 swtich2 var3 | ||||
| /// \F T/ | | /// \F T/ | | ||||
| /// switch1 | | /// switch1 | | ||||
| @@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() { | |||||
| /// add1 | /// add1 | ||||
| /// / \T | /// / \T | ||||
| /// var3 swtich2 | /// var3 swtich2 | ||||
| /// T/ \ | |||||
| /// switch1 \ | |||||
| /// / \ \ | |||||
| /// T/ \. | |||||
| /// switch1 \. | |||||
| /// / \ \. | |||||
| /// var1 var2 var4 | /// var1 var2 var4 | ||||
| ComputeGraphPtr BuildGraph3() { | ComputeGraphPtr BuildGraph3() { | ||||
| auto builder = ut::GraphBuilder("g3"); | auto builder = ut::GraphBuilder("g3"); | ||||
| @@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() { | |||||
| /// netoutput1 | /// netoutput1 | ||||
| /// | | /// | | ||||
| /// merge1 | /// merge1 | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// add1 addn1 | /// add1 addn1 | ||||
| /// / \T F/ | /// / \T F/ | ||||
| /// var3 swtich2 | /// var3 swtich2 | ||||
| @@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) { | |||||
| } | } | ||||
| /// ----> netoutput1 | /// ----> netoutput1 | ||||
| /// / | \ | |||||
| /// / | \. | |||||
| /// transdata1 transdata2 transdata3 | /// transdata1 transdata2 transdata3 | ||||
| /// \ / | | /// \ / | | ||||
| /// var1-------------- | /// var1-------------- | ||||
| @@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() { | |||||
| } | } | ||||
| /// ---------> netoutput1 | /// ---------> netoutput1 | ||||
| /// / | \ | |||||
| /// / | \. | |||||
| /// transdata1 transdata2(l1) transdata3(l1) | /// transdata1 transdata2(l1) transdata3(l1) | ||||
| /// \ / | | /// \ / | | ||||
| /// var1------------------ | /// var1------------------ | ||||
| @@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||||
| /// -->transpose1 -->transpose3-->sinh2 | /// -->transpose1 -->transpose3-->sinh2 | ||||
| /// | \ / | /// | \ / | ||||
| /// | -->transpose2 | /// | -->transpose2 | ||||
| /// | \ | |||||
| /// | \. | |||||
| /// / -->cast3-->cast4-->sinh3 | /// / -->cast3-->cast4-->sinh3 | ||||
| /// / | /// / | ||||
| /// / -->transpose4-->transpose5-->sinh4 | /// / -->transpose4-->transpose5-->sinh4 | ||||
| /// / / | /// / / | ||||
| /// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 | /// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 | ||||
| /// \ \ | |||||
| /// \ \. | |||||
| /// \ -->sinh6 | /// \ -->sinh6 | ||||
| /// \ | |||||
| /// \. | |||||
| /// \ -->transpose6-->transpose7-->sinh9 | /// \ -->transpose6-->transpose7-->sinh9 | ||||
| /// \ / | /// \ / | ||||
| /// -->reshape-->cast6-->cast7-->sinh8 | /// -->reshape-->cast6-->cast7-->sinh8 | ||||
| /// \ | |||||
| /// \. | |||||
| /// -->sinh7 | /// -->sinh7 | ||||
| /// after optimized graph | /// after optimized graph | ||||
| @@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||||
| /// / /-->transpose3-->sinh2 | /// / /-->transpose3-->sinh2 | ||||
| /// -->Cast1 | /// -->Cast1 | ||||
| /// / \-->sinh7 | /// / \-->sinh7 | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// / -->sinh9 | /// / -->sinh9 | ||||
| /// Node4D | /// Node4D | ||||
| /// \ -->sinh4 | /// \ -->sinh4 | ||||
| /// \ / | /// \ / | ||||
| /// -->Cast5-->sinh5 | /// -->Cast5-->sinh5 | ||||
| /// \ \ | |||||
| /// \ \. | |||||
| /// \ -->sinh6 | /// \ -->sinh6 | ||||
| /// \ | |||||
| /// \. | |||||
| /// -->Cast7-->sinh8 | /// -->Cast7-->sinh8 | ||||
| ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
| @@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran | |||||
| /// TransData TransData ... MatMul ... | /// TransData TransData ... MatMul ... | ||||
| /// \ | / / / | /// \ | / / / | ||||
| /// HcomAllReduce | /// HcomAllReduce | ||||
| /// / | \ \ \ | |||||
| /// / | \ \ \. | |||||
| /// TransData TransData ... RealDiv ... | /// TransData TransData ... RealDiv ... | ||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
| NodePtr allreduce = | NodePtr allreduce = | ||||
| @@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans | |||||
| /// TransData TransData ... MatMul ... | /// TransData TransData ... MatMul ... | ||||
| /// \ | / / / | /// \ | / / / | ||||
| /// HcomAllReduce | /// HcomAllReduce | ||||
| /// / | \ \ \ | |||||
| /// / | \ \ \. | |||||
| /// TransData TransData ... RealDiv ... | /// TransData TransData ... RealDiv ... | ||||
| size_t symmetric_transdata_num = 20; | size_t symmetric_transdata_num = 20; | ||||
| size_t asymmetric_transdata_num = 20; | size_t asymmetric_transdata_num = 20; | ||||
| @@ -66,7 +66,7 @@ namespace { | |||||
| /// transdata2 | /// transdata2 | ||||
| /// | | /// | | ||||
| /// assign1 | /// assign1 | ||||
| /// / \ | |||||
| /// / \. | |||||
| /// transdata1 | | /// transdata1 | | ||||
| /// | | | /// | | | ||||
| /// var1 const1 | /// var1 const1 | ||||
| @@ -35,8 +35,8 @@ namespace { | |||||
| /// shapeNo1 | /// shapeNo1 | ||||
| /// | | /// | | ||||
| /// addnYes1 | /// addnYes1 | ||||
| /// / \ | |||||
| /// / \ | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// const1 const2 | /// const1 const2 | ||||
| ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
| @@ -57,9 +57,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
| /// | /// | ||||
| /// netoutput1 | /// netoutput1 | ||||
| /// / \ \ | |||||
| /// add1 assign1 \ | |||||
| /// / \ / \ \ | |||||
| /// / \ \. | |||||
| /// add1 assign1 \. | |||||
| /// / \ / \ \. | |||||
| /// var1 var2 const1 var3 | /// var1 var2 const1 var3 | ||||
| ComputeGraphPtr BuildGraph2() { | ComputeGraphPtr BuildGraph2() { | ||||