diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 505b1908..07355ab5 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -808,7 +808,7 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector return SUCCESS; } -Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { +Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) { GE_CHECK_NOTNULL(op_desc); if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); @@ -832,7 +832,11 @@ Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { } node_op.BreakConnect(); } - auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); + auto comp_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(comp_graph); + auto node = comp_graph->FindNode(op_desc->GetName()); + GE_CHECK_NOTNULL(node); + auto op = OpDescUtils::CreateOperatorFromNode(node); auto ret = op_desc->CallInferFormatFunc(op); if (ret != GRAPH_SUCCESS) { REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", @@ -879,7 +883,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in Graph graph; GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); - GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc)); + GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph)); // 2. check engine type when compile online if (model_file_name == kFileNameSuffix) { diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index ee51d29d..5da5a593 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -106,7 +106,7 @@ class GE_FUNC_VISIBILITY GeGenerator { bool CheckNoAicore(const ComputeGraphPtr &graph); void RemoveConst(const vector &inputs, vector &outputs); Status CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs); - Status InferFormatForSingleOp(OpDescPtr &op_desc); + Status InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph); using GeRootModelPtr = std::shared_ptr; Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index 1bb4430f..b3abb2f9 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -83,12 +83,16 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { graphStatus TestFunc(Operator &op) { return 0; } graphStatus TestFunc1(Operator &op) { return 1; } TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { + ComputeGraphPtr compute_graph = MakeShared("graph_name"); + auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); shared_ptr op_desc = make_shared("add", "add"); + compute_graph->AddNode(op_desc); GeGenerator generator; - EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS); + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc, graph), SUCCESS); shared_ptr op_desc1 = make_shared("Add", "Add"); - EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS); + compute_graph->AddNode(op_desc1); + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1, graph), SUCCESS); OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); shared_ptr op_desc2 = make_shared("MatMulV2", "MatMulV2"); GeTensorDesc tensor_desc; @@ -99,7 +103,8 @@ TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); - EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED); + compute_graph->AddNode(op_desc2); + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2, graph), FAILED); } TEST_F(UtestGeGenerator, test_build_single_op_online) {