| @@ -808,7 +808,7 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { | |||||
| Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) { | |||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { | if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { | ||||
| auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); | auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); | ||||
| @@ -832,7 +832,11 @@ Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { | |||||
| } | } | ||||
| node_op.BreakConnect(); | node_op.BreakConnect(); | ||||
| } | } | ||||
| auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); | |||||
| auto comp_graph = GraphUtils::GetComputeGraph(graph); | |||||
| GE_CHECK_NOTNULL(comp_graph); | |||||
| auto node = comp_graph->FindNode(op_desc->GetName()); | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto op = OpDescUtils::CreateOperatorFromNode(node); | |||||
| auto ret = op_desc->CallInferFormatFunc(op); | auto ret = op_desc->CallInferFormatFunc(op); | ||||
| if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
| REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", | REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", | ||||
| @@ -879,7 +883,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
| Graph graph; | Graph graph; | ||||
| GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), | GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), | ||||
| "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); | "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); | ||||
| GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc)); | |||||
| GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph)); | |||||
| // 2. check engine type when compile online | // 2. check engine type when compile online | ||||
| if (model_file_name == kFileNameSuffix) { | if (model_file_name == kFileNameSuffix) { | ||||
| @@ -106,7 +106,7 @@ class GE_FUNC_VISIBILITY GeGenerator { | |||||
| bool CheckNoAicore(const ComputeGraphPtr &graph); | bool CheckNoAicore(const ComputeGraphPtr &graph); | ||||
| void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs); | void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs); | ||||
| Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | ||||
| Status InferFormatForSingleOp(OpDescPtr &op_desc); | |||||
| Status InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph); | |||||
| using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | ||||
| Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); | Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); | ||||
| @@ -83,12 +83,16 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { | |||||
| graphStatus TestFunc(Operator &op) { return 0; } | graphStatus TestFunc(Operator &op) { return 0; } | ||||
| graphStatus TestFunc1(Operator &op) { return 1; } | graphStatus TestFunc1(Operator &op) { return 1; } | ||||
| TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { | TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { | ||||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("graph_name"); | |||||
| auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
| OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); | OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); | ||||
| shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("add", "add"); | shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("add", "add"); | ||||
| compute_graph->AddNode(op_desc); | |||||
| GeGenerator generator; | GeGenerator generator; | ||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS); | |||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc, graph), SUCCESS); | |||||
| shared_ptr<OpDesc> op_desc1 = make_shared<OpDesc>("Add", "Add"); | shared_ptr<OpDesc> op_desc1 = make_shared<OpDesc>("Add", "Add"); | ||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS); | |||||
| compute_graph->AddNode(op_desc1); | |||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1, graph), SUCCESS); | |||||
| OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); | OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); | ||||
| shared_ptr<OpDesc> op_desc2 = make_shared<OpDesc>("MatMulV2", "MatMulV2"); | shared_ptr<OpDesc> op_desc2 = make_shared<OpDesc>("MatMulV2", "MatMulV2"); | ||||
| GeTensorDesc tensor_desc; | GeTensorDesc tensor_desc; | ||||
| @@ -99,7 +103,8 @@ TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { | |||||
| EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | ||||
| EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | ||||
| EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | ||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED); | |||||
| compute_graph->AddNode(op_desc2); | |||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2, graph), FAILED); | |||||
| } | } | ||||
| TEST_F(UtestGeGenerator, test_build_single_op_online) { | TEST_F(UtestGeGenerator, test_build_single_op_online) { | ||||