diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index fa6d8fa8..8a94aa9b 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -31,6 +31,7 @@ #include "graph/ge_context.h" #include "graph/manager/graph_manager.h" #include "graph/manager/util/rt_context_util.h" +#include "graph/operator_factory_impl.h" #include "graph/opsproto_manager.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" @@ -803,6 +804,41 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector return SUCCESS; } +Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(op_desc); + if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { + auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. op type: %s", op_desc->GetType().c_str()); + } else { + GELOGD("get op from OperatorFactory success. op type: %s", op_desc->GetType().c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + if (temp_op_desc == nullptr) { + REPORT_INNER_ERROR("E19999", "GetOpDescFromOperator failed, as return nullptr, type:%s", + op_desc->GetType().c_str()); + GELOGE(FAILED, "[Get][OpDesc] temp op desc is null, type:%s", op_desc->GetType().c_str()); + return FAILED; + } + if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("InferFormatForSingleOp UpdateInputName failed"); + } + if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("InferFormatForSingleOp UpdateOutputName failed"); + } + } + node_op.BreakConnect(); + } + auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); + auto ret = op_desc->CallInferFormatFunc(op); + if (ret != GRAPH_SUCCESS) { + REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", + op_desc->GetName().c_str()); + GELOGE(FAILED, "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str()); + return FAILED; + } + return SUCCESS; +} + Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline, int32_t compile_flag) { @@ -843,9 +879,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in Graph graph; GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); - auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - GE_CHK_STATUS_RET(op_desc->CallInferFormatFunc(op), - "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str()); + GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc)); // 2. check engine type when compile online if (model_file_name == kFileNameSuffix) { diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index 24f969dd..ee51d29d 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -106,6 +106,7 @@ class GE_FUNC_VISIBILITY GeGenerator { bool CheckNoAicore(const ComputeGraphPtr &graph); void RemoveConst(const vector &inputs, vector &outputs); Status CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs); + Status InferFormatForSingleOp(OpDescPtr &op_desc); using GeRootModelPtr = std::shared_ptr; Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index fb256c7c..1bb4430f 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -23,6 +23,7 @@ #include "graph/attr_value.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" +#include "graph/operator_factory_impl.h" #include "../graph/passes/graph_builder_utils.h" #include "../graph/manager/graph_manager.h" #include "all_ops.h" @@ -79,6 +80,27 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, "offline_"), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); } */ +graphStatus TestFunc(Operator &op) { return 0; } +graphStatus TestFunc1(Operator &op) { return 1; } +TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { + OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); + shared_ptr op_desc = make_shared("add", "add"); + GeGenerator generator; + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS); + shared_ptr op_desc1 = make_shared("Add", "Add"); + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS); + OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); + shared_ptr op_desc2 = make_shared("MatMulV2", "MatMulV2"); + GeTensorDesc tensor_desc; + EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); + EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); + EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); + EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); + EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); + EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); + EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED); +} TEST_F(UtestGeGenerator, test_build_single_op_online) { GeTensorDesc tensor_desc; diff --git a/tests/ut/ge/graph/ops_stub.h b/tests/ut/ge/graph/ops_stub.h index 2a71d80a..c122befa 100644 --- a/tests/ut/ge/graph/ops_stub.h +++ b/tests/ut/ge/graph/ops_stub.h @@ -144,6 +144,17 @@ REG_OP(Data) DT_UINT64, DT_BOOL, DT_DOUBLE})) .OP_END_FACTORY_REG(GuaranteeConst) + REG_OP(MatMulV2) + .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .ATTR(transpose_x1, Bool, false) + .ATTR(transpose_x2, Bool, false) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(MatMulV2) + IMPLEMT_INFERFUNC(GuaranteeConst, GuaranteeConstInfer) { TensorDesc tensorDesc = op.GetInputDesc("x"); (void)op.UpdateOutputDesc("y", tensorDesc);