Merge pull request !1923 from yangyongqiang/infer_format_singletags/v1.5.1
@@ -808,7 +808,7 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { | |||||
Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) { | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { | if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { | ||||
auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); | auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); | ||||
@@ -832,7 +832,11 @@ Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { | |||||
} | } | ||||
node_op.BreakConnect(); | node_op.BreakConnect(); | ||||
} | } | ||||
auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); | |||||
auto comp_graph = GraphUtils::GetComputeGraph(graph); | |||||
GE_CHECK_NOTNULL(comp_graph); | |||||
auto node = comp_graph->FindNode(op_desc->GetName()); | |||||
GE_CHECK_NOTNULL(node); | |||||
auto op = OpDescUtils::CreateOperatorFromNode(node); | |||||
auto ret = op_desc->CallInferFormatFunc(op); | auto ret = op_desc->CallInferFormatFunc(op); | ||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", | REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", | ||||
@@ -879,7 +883,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
Graph graph; | Graph graph; | ||||
GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), | GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), | ||||
"[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); | "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); | ||||
GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc)); | |||||
GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph)); | |||||
// 2. check engine type when compile online | // 2. check engine type when compile online | ||||
if (model_file_name == kFileNameSuffix) { | if (model_file_name == kFileNameSuffix) { | ||||
@@ -106,7 +106,7 @@ class GE_FUNC_VISIBILITY GeGenerator { | |||||
bool CheckNoAicore(const ComputeGraphPtr &graph); | bool CheckNoAicore(const ComputeGraphPtr &graph); | ||||
void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs); | void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs); | ||||
Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | ||||
Status InferFormatForSingleOp(OpDescPtr &op_desc); | |||||
Status InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph); | |||||
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | ||||
Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); | Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); | ||||
@@ -83,12 +83,16 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { | |||||
graphStatus TestFunc(Operator &op) { return 0; } | graphStatus TestFunc(Operator &op) { return 0; } | ||||
graphStatus TestFunc1(Operator &op) { return 1; } | graphStatus TestFunc1(Operator &op) { return 1; } | ||||
TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { | TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { | ||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("graph_name"); | |||||
auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); | OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); | ||||
shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("add", "add"); | shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("add", "add"); | ||||
compute_graph->AddNode(op_desc); | |||||
GeGenerator generator; | GeGenerator generator; | ||||
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS); | |||||
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc, graph), SUCCESS); | |||||
shared_ptr<OpDesc> op_desc1 = make_shared<OpDesc>("Add", "Add"); | shared_ptr<OpDesc> op_desc1 = make_shared<OpDesc>("Add", "Add"); | ||||
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS); | |||||
compute_graph->AddNode(op_desc1); | |||||
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1, graph), SUCCESS); | |||||
OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); | OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); | ||||
shared_ptr<OpDesc> op_desc2 = make_shared<OpDesc>("MatMulV2", "MatMulV2"); | shared_ptr<OpDesc> op_desc2 = make_shared<OpDesc>("MatMulV2", "MatMulV2"); | ||||
GeTensorDesc tensor_desc; | GeTensorDesc tensor_desc; | ||||
@@ -99,7 +103,8 @@ TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { | |||||
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | ||||
EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | ||||
EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | ||||
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED); | |||||
compute_graph->AddNode(op_desc2); | |||||
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2, graph), FAILED); | |||||
} | } | ||||
TEST_F(UtestGeGenerator, test_build_single_op_online) { | TEST_F(UtestGeGenerator, test_build_single_op_online) { | ||||