Browse Source

!1311 Fix bug of const input index.

From: @zhao_zhixuan
Reviewed-by: @xchu42,@ji_chen
Signed-off-by: @ji_chen
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
96eaa5364d
3 changed files with 31 additions and 6 deletions
  1. +21
    -6
      ge/generator/ge_generator.cc
  2. +1
    -0
      inc/framework/generator/ge_generator.h
  3. +9
    -0
      tests/ut/ge/generator/ge_generator_unittest.cc

+ 21
- 6
ge/generator/ge_generator.cc View File

@@ -154,7 +154,7 @@ static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_ty
} }


static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index, static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index,
bool attr) {
bool attr, int32_t &data_index) {
GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);


@@ -197,9 +197,10 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const
"[Add][InputDesc]fail for node:%s", data_op->GetName().c_str()); "[Add][InputDesc]fail for node:%s", data_op->GetName().c_str());
GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED, GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
"[Add][OutputDesc]fail for node:%s", data_op->GetName().c_str()); "[Add][OutputDesc]fail for node:%s", data_op->GetName().c_str());
if (attr) {
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, index), return FAILED,
if (attr && !is_const) {
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index), return FAILED,
"[Set][Attr:%s]fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str()); "[Set][Attr:%s]fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
++data_index;
} }


ge::NodePtr arg_node = graph->AddNode(data_op); ge::NodePtr arg_node = graph->AddNode(data_op);
@@ -709,6 +710,17 @@ bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) {
return true; return true;
} }


void GeGenerator::RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs) {
for (auto &input : inputs) {
GeTensorDesc input_desc = input.GetTensorDesc();
bool is_const = false;
(void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
if (!is_const) {
outputs.emplace_back(input);
}
}
}

Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
const vector<GeTensor> &outputs) { const vector<GeTensor> &outputs) {
GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
@@ -773,7 +785,9 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
GELOGI("ATC parser success in single op build."); GELOGI("ATC parser success in single op build.");


GeRootModelPtr ge_root_model = nullptr; GeRootModelPtr ge_root_model = nullptr;
GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, ge_root_model));
vector<GeTensor> data_inputs;
RemoveConst(inputs, data_inputs);
GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model));
map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs(); map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model);
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
@@ -856,18 +870,19 @@ Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor


// 2. Create InputData node. // 2. Create InputData node.
int32_t arg_index = 0; int32_t arg_index = 0;
int32_t data_index = 0;
if (inputs.empty()) { if (inputs.empty()) {
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
continue; continue;
} }
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false));
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index));
arg_index++; arg_index++;
} }
} else { } else {
for (const auto &in_desc : inputs) { for (const auto &in_desc : inputs) {
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true));
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index));
arg_index++; arg_index++;
} }
} }


+ 1
- 0
inc/framework/generator/ge_generator.h View File

@@ -99,6 +99,7 @@ class GE_FUNC_VISIBILITY GeGenerator {
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline = true); bool is_offline = true);
bool CheckNoAicore(const ComputeGraphPtr &graph); bool CheckNoAicore(const ComputeGraphPtr &graph);
void RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs);
Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);


using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;


+ 9
- 0
tests/ut/ge/generator/ge_generator_unittest.cc View File

@@ -128,4 +128,13 @@ TEST_F(UtestGeGenerator, test_set_model_name) {
ge_root_model->root_graph_ = std::move(graph); ge_root_model->root_graph_ = std::move(graph);
EXPECT_EQ(generator.SetModelNameForDump(ge_root_model), SUCCESS); EXPECT_EQ(generator.SetModelNameForDump(ge_root_model), SUCCESS);
} }

TEST_F(UtestGeGenerator, test_remove_const) {
GeGenerator generator;
GeTensorDesc tensor_desc;
GeTensor tensor(tensor_desc);
const vector<GeTensor> inputs = {tensor};
vector<GeTensor> outputs;
generator.RemoveConst(inputs, outputs);
}
} // namespace ge } // namespace ge

Loading…
Cancel
Save