From 374ff33278efe0511ec3a8c417de1f04c57cadab Mon Sep 17 00:00:00 2001 From: yskhhh Date: Mon, 10 May 2021 14:25:41 +0800 Subject: [PATCH] ake data op type be acceptable --- ge/graph/preprocess/graph_preprocess.cc | 3 +- .../preprocess/graph_preprocess_unittest.cc | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 7b761cd0..4e9046e4 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1257,7 +1257,8 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i // Since ME dont differentiate between RefSwitch and Switch, and only issue Switch. static std::set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, - ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; + ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH, + ge::DATA}; GE_CHECK_NOTNULL(node); const auto &op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); diff --git a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc index 8d0be31d..6c5babfc 100644 --- a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc +++ b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc @@ -72,6 +72,19 @@ ComputeGraphPtr BuildGraph3() { return builder.GetGraph(); } +ComputeGraphPtr BuildGraph5() { + auto builder = ut::GraphBuilder("g5"); + auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3}); + auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10}); + auto add = builder.AddNode("add", ADD, 2, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + builder.AddDataEdge(data1, 0, add, 0); + builder.AddDataEdge(data2, 0, add, 1); + builder.AddDataEdge(add, 0,netoutput, 0); + return builder.GetGraph(); +} + /* * MapIndex Data1 subgraph1 subgraph2 * \ / @@ -154,6 +167,23 @@ TEST_F(UtestGraphPreproces, test_update_input_output1) { EXPECT_EQ(ret, SUCCESS); } + +TEST_F(UtestGraphPreproces, check_ref_op_data_succ) { + GraphPrepare graph_preparer; + ComputeGraphPtr graph_test = BuildGraph5(); + NodePtr add_node = nullptr; + for (auto &node : graph_test->GetAllNodes()) { + if (node->GetName() == "add") { + add_node = node; + } + } + EXPECT_NE(add_node, nullptr); + string input_name = "__input0"; + std::set ref_nodes; + auto ret = graph_preparer.CheckRefInputNode(add_node, input_name, ref_nodes); + EXPECT_EQ(ret, SUCCESS); +} + TEST_F(UtestGraphPreproces, test_update_dtype_mbatch_case) { ge::GraphPrepare graph_prepare; graph_prepare.compute_graph_ = BuildGraph4();