Browse Source

bugfix for infer format for single op

tags/v1.3.0
y00500818 3 years ago
parent
commit
899d143a47
4 changed files with 71 additions and 3 deletions
  1. +37
    -3
      ge/generator/ge_generator.cc
  2. +1
    -0
      inc/framework/generator/ge_generator.h
  3. +22
    -0
      tests/ut/ge/generator/ge_generator_unittest.cc
  4. +11
    -0
      tests/ut/ge/graph/ops_stub.h

+ 37
- 3
ge/generator/ge_generator.cc View File

@@ -31,6 +31,7 @@
#include "graph/ge_context.h"
#include "graph/manager/graph_manager.h"
#include "graph/manager/util/rt_context_util.h"
#include "graph/operator_factory_impl.h"
#include "graph/opsproto_manager.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/type_utils.h"
@@ -803,6 +804,41 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor>
return SUCCESS;
}

Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(op_desc);
if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) {
auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType());
if (node_op.IsEmpty()) {
GELOGW("get op from OperatorFactory fail. op type: %s", op_desc->GetType().c_str());
} else {
GELOGD("get op from OperatorFactory success. op type: %s", op_desc->GetType().c_str());
auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
if (temp_op_desc == nullptr) {
REPORT_INNER_ERROR("E19999", "GetOpDescFromOperator failed, as return nullptr, type:%s",
op_desc->GetType().c_str());
GELOGE(FAILED, "[Get][OpDesc] temp op desc is null, type:%s", op_desc->GetType().c_str());
return FAILED;
}
if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
GELOGW("InferFormatForSingleOp UpdateInputName failed");
}
if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
GELOGW("InferFormatForSingleOp UpdateOutputName failed");
}
}
node_op.BreakConnect();
}
auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc);
auto ret = op_desc->CallInferFormatFunc(op);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail",
op_desc->GetName().c_str());
GELOGE(FAILED, "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str());
return FAILED;
}
return SUCCESS;
}

Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline, int32_t compile_flag) {
@@ -843,9 +879,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
Graph graph;
GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph),
"[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str());
auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc);
GE_CHK_STATUS_RET(op_desc->CallInferFormatFunc(op),
"[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str());
GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc));

// 2. check engine type when compile online
if (model_file_name == kFileNameSuffix) {


+ 1
- 0
inc/framework/generator/ge_generator.h View File

@@ -106,6 +106,7 @@ class GE_FUNC_VISIBILITY GeGenerator {
bool CheckNoAicore(const ComputeGraphPtr &graph);
void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs);
Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
Status InferFormatForSingleOp(OpDescPtr &op_desc);

using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;
Status SetModelNameForDump(const GeRootModelPtr &ge_root_model);


+ 22
- 0
tests/ut/ge/generator/ge_generator_unittest.cc View File

@@ -23,6 +23,7 @@
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/operator_factory_impl.h"
#include "../graph/passes/graph_builder_utils.h"
#include "../graph/manager/graph_manager.h"
#include "all_ops.h"
@@ -79,6 +80,27 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) {
EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, "offline_"), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED);
}
*/
graphStatus TestFunc(Operator &op) { return 0; }
graphStatus TestFunc1(Operator &op) { return 1; }
TEST_F(UtestGeGenerator, test_infer_format_for_single_op) {
OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc);
shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("add", "add");
GeGenerator generator;
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS);
shared_ptr<OpDesc> op_desc1 = make_shared<OpDesc>("Add", "Add");
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS);
OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1);
shared_ptr<OpDesc> op_desc2 = make_shared<OpDesc>("MatMulV2", "MatMulV2");
GeTensorDesc tensor_desc;
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS);
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED);
}

TEST_F(UtestGeGenerator, test_build_single_op_online) {
GeTensorDesc tensor_desc;


+ 11
- 0
tests/ut/ge/graph/ops_stub.h View File

@@ -144,6 +144,17 @@ REG_OP(Data)
DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(GuaranteeConst)

REG_OP(MatMulV2)
.INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
.INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
.OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
.ATTR(transpose_x1, Bool, false)
.ATTR(transpose_x2, Bool, false)
.ATTR(offset_x, Int, 0)
.OP_END_FACTORY_REG(MatMulV2)

IMPLEMT_INFERFUNC(GuaranteeConst, GuaranteeConstInfer) {
TensorDesc tensorDesc = op.GetInputDesc("x");
(void)op.UpdateOutputDesc("y", tensorDesc);


Loading…
Cancel
Save