| @@ -31,6 +31,7 @@ | |||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/manager/graph_manager.h" | #include "graph/manager/graph_manager.h" | ||||
| #include "graph/manager/util/rt_context_util.h" | #include "graph/manager/util/rt_context_util.h" | ||||
| #include "graph/operator_factory_impl.h" | |||||
| #include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -803,6 +804,41 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { | |||||
| auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); | |||||
| if (node_op.IsEmpty()) { | |||||
| GELOGW("get op from OperatorFactory fail. op type: %s", op_desc->GetType().c_str()); | |||||
| } else { | |||||
| GELOGD("get op from OperatorFactory success. op type: %s", op_desc->GetType().c_str()); | |||||
| auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); | |||||
| if (temp_op_desc == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "GetOpDescFromOperator failed, as return nullptr, type:%s", | |||||
| op_desc->GetType().c_str()); | |||||
| GELOGE(FAILED, "[Get][OpDesc] temp op desc is null, type:%s", op_desc->GetType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { | |||||
| GELOGW("InferFormatForSingleOp UpdateInputName failed"); | |||||
| } | |||||
| if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { | |||||
| GELOGW("InferFormatForSingleOp UpdateOutputName failed"); | |||||
| } | |||||
| } | |||||
| node_op.BreakConnect(); | |||||
| } | |||||
| auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); | |||||
| auto ret = op_desc->CallInferFormatFunc(op); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", | |||||
| op_desc->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | ||||
| const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | ||||
| bool is_offline, int32_t compile_flag) { | bool is_offline, int32_t compile_flag) { | ||||
| @@ -843,9 +879,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
| Graph graph; | Graph graph; | ||||
| GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), | GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), | ||||
| "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); | "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); | ||||
| auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); | |||||
| GE_CHK_STATUS_RET(op_desc->CallInferFormatFunc(op), | |||||
| "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str()); | |||||
| GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc)); | |||||
| // 2. check engine type when compile online | // 2. check engine type when compile online | ||||
| if (model_file_name == kFileNameSuffix) { | if (model_file_name == kFileNameSuffix) { | ||||
| @@ -106,6 +106,7 @@ class GE_FUNC_VISIBILITY GeGenerator { | |||||
| bool CheckNoAicore(const ComputeGraphPtr &graph); | bool CheckNoAicore(const ComputeGraphPtr &graph); | ||||
| void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs); | void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs); | ||||
| Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | ||||
| Status InferFormatForSingleOp(OpDescPtr &op_desc); | |||||
| using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | ||||
| Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); | Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "graph/attr_value.h" | #include "graph/attr_value.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/operator_factory_impl.h" | |||||
| #include "../graph/passes/graph_builder_utils.h" | #include "../graph/passes/graph_builder_utils.h" | ||||
| #include "../graph/manager/graph_manager.h" | #include "../graph/manager/graph_manager.h" | ||||
| #include "all_ops.h" | #include "all_ops.h" | ||||
| @@ -79,6 +80,27 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { | |||||
| EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, "offline_"), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); | EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, "offline_"), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); | ||||
| } | } | ||||
| */ | */ | ||||
| graphStatus TestFunc(Operator &op) { return 0; } | |||||
| graphStatus TestFunc1(Operator &op) { return 1; } | |||||
| TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { | |||||
| OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); | |||||
| shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("add", "add"); | |||||
| GeGenerator generator; | |||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS); | |||||
| shared_ptr<OpDesc> op_desc1 = make_shared<OpDesc>("Add", "Add"); | |||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS); | |||||
| OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); | |||||
| shared_ptr<OpDesc> op_desc2 = make_shared<OpDesc>("MatMulV2", "MatMulV2"); | |||||
| GeTensorDesc tensor_desc; | |||||
| EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
| EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
| EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
| EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
| EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
| EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
| EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
| EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED); | |||||
| } | |||||
| TEST_F(UtestGeGenerator, test_build_single_op_online) { | TEST_F(UtestGeGenerator, test_build_single_op_online) { | ||||
| GeTensorDesc tensor_desc; | GeTensorDesc tensor_desc; | ||||
| @@ -144,6 +144,17 @@ REG_OP(Data) | |||||
| DT_UINT64, DT_BOOL, DT_DOUBLE})) | DT_UINT64, DT_BOOL, DT_DOUBLE})) | ||||
| .OP_END_FACTORY_REG(GuaranteeConst) | .OP_END_FACTORY_REG(GuaranteeConst) | ||||
| REG_OP(MatMulV2) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) | |||||
| .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) | |||||
| .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) | |||||
| .ATTR(transpose_x1, Bool, false) | |||||
| .ATTR(transpose_x2, Bool, false) | |||||
| .ATTR(offset_x, Int, 0) | |||||
| .OP_END_FACTORY_REG(MatMulV2) | |||||
| IMPLEMT_INFERFUNC(GuaranteeConst, GuaranteeConstInfer) { | IMPLEMT_INFERFUNC(GuaranteeConst, GuaranteeConstInfer) { | ||||
| TensorDesc tensorDesc = op.GetInputDesc("x"); | TensorDesc tensorDesc = op.GetInputDesc("x"); | ||||
| (void)op.UpdateOutputDesc("y", tensorDesc); | (void)op.UpdateOutputDesc("y", tensorDesc); | ||||