Browse Source

ake data op type be acceptable

tags/v1.3.0
yskhhh 3 years ago
parent
commit
374ff33278
2 changed files with 32 additions and 1 deletions
  1. +2
    -1
      ge/graph/preprocess/graph_preprocess.cc
  2. +30
    -0
      tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc

+ 2
- 1
ge/graph/preprocess/graph_preprocess.cc View File

@@ -1257,7 +1257,8 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i
// Since ME dont differentiate between RefSwitch and Switch, and only issue Switch. // Since ME dont differentiate between RefSwitch and Switch, and only issue Switch.
static std::set<std::string> acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, static std::set<std::string> acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP,
ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, ge::REFSWITCH, ge::REFMERGE, ge::REFENTER,
ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH};
ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH,
ge::DATA};
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
const auto &op_desc = node->GetOpDesc(); const auto &op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);


+ 30
- 0
tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc View File

@@ -72,6 +72,19 @@ ComputeGraphPtr BuildGraph3() {
return builder.GetGraph(); return builder.GetGraph();
} }


ComputeGraphPtr BuildGraph5() {
auto builder = ut::GraphBuilder("g5");
auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3});
auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10});
auto add = builder.AddNode("add", ADD, 2, 1);
auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

builder.AddDataEdge(data1, 0, add, 0);
builder.AddDataEdge(data2, 0, add, 1);
builder.AddDataEdge(add, 0,netoutput, 0);
return builder.GetGraph();
}

/* /*
* MapIndex Data1 subgraph1 subgraph2 * MapIndex Data1 subgraph1 subgraph2
* \ / * \ /
@@ -154,6 +167,23 @@ TEST_F(UtestGraphPreproces, test_update_input_output1) {
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }



TEST_F(UtestGraphPreproces, check_ref_op_data_succ) {
GraphPrepare graph_preparer;
ComputeGraphPtr graph_test = BuildGraph5();
NodePtr add_node = nullptr;
for (auto &node : graph_test->GetAllNodes()) {
if (node->GetName() == "add") {
add_node = node;
}
}
EXPECT_NE(add_node, nullptr);
string input_name = "__input0";
std::set<NodePtr> ref_nodes;
auto ret = graph_preparer.CheckRefInputNode(add_node, input_name, ref_nodes);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestGraphPreproces, test_update_dtype_mbatch_case) { TEST_F(UtestGraphPreproces, test_update_dtype_mbatch_case) {
ge::GraphPrepare graph_prepare; ge::GraphPrepare graph_prepare;
graph_prepare.compute_graph_ = BuildGraph4(); graph_prepare.compute_graph_ = BuildGraph4();


Loading…
Cancel
Save