Browse Source

bugfix for InferFormatForSingleOp

tags/v1.5.1
y00500818 3 years ago
parent
commit
ebf39e513d
3 changed files with 16 additions and 7 deletions
  1. +7
    -3
      ge/generator/ge_generator.cc
  2. +1
    -1
      inc/framework/generator/ge_generator.h
  3. +8
    -3
      tests/ut/ge/generator/ge_generator_unittest.cc

+ 7
- 3
ge/generator/ge_generator.cc View File

@@ -808,7 +808,7 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor>
return SUCCESS;
}

Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) {
Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) {
GE_CHECK_NOTNULL(op_desc);
if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) {
auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType());
@@ -832,7 +832,11 @@ Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) {
}
node_op.BreakConnect();
}
auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc);
auto comp_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(comp_graph);
auto node = comp_graph->FindNode(op_desc->GetName());
GE_CHECK_NOTNULL(node);
auto op = OpDescUtils::CreateOperatorFromNode(node);
auto ret = op_desc->CallInferFormatFunc(op);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail",
@@ -879,7 +883,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
Graph graph;
GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph),
"[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str());
GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc));
GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph));

// 2. check engine type when compile online
if (model_file_name == kFileNameSuffix) {


+ 1
- 1
inc/framework/generator/ge_generator.h View File

@@ -106,7 +106,7 @@ class GE_FUNC_VISIBILITY GeGenerator {
bool CheckNoAicore(const ComputeGraphPtr &graph);
void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs);
Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
Status InferFormatForSingleOp(OpDescPtr &op_desc);
Status InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph);

using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;
Status SetModelNameForDump(const GeRootModelPtr &ge_root_model);


+ 8
- 3
tests/ut/ge/generator/ge_generator_unittest.cc View File

@@ -83,12 +83,16 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) {
graphStatus TestFunc(Operator &op) { return 0; }
graphStatus TestFunc1(Operator &op) { return 1; }
TEST_F(UtestGeGenerator, test_infer_format_for_single_op) {
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("graph_name");
auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc);
shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("add", "add");
compute_graph->AddNode(op_desc);
GeGenerator generator;
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS);
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc, graph), SUCCESS);
shared_ptr<OpDesc> op_desc1 = make_shared<OpDesc>("Add", "Add");
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS);
compute_graph->AddNode(op_desc1);
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1, graph), SUCCESS);
OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1);
shared_ptr<OpDesc> op_desc2 = make_shared<OpDesc>("MatMulV2", "MatMulV2");
GeTensorDesc tensor_desc;
@@ -99,7 +103,8 @@ TEST_F(UtestGeGenerator, test_infer_format_for_single_op) {
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED);
compute_graph->AddNode(op_desc2);
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2, graph), FAILED);
}

TEST_F(UtestGeGenerator, test_build_single_op_online) {


Loading…
Cancel
Save