diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index c138ec0d..4cd5d34f 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -154,7 +154,7 @@ static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_ty } static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index, - bool attr) { + bool attr, int32_t &data_index) { GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID); @@ -197,9 +197,10 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const "[Add][InputDesc]fail for node:%s", data_op->GetName().c_str()); GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "[Add][OutputDesc]fail for node:%s", data_op->GetName().c_str()); - if (attr) { - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, index), return FAILED, + if (attr && !is_const) { + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index), return FAILED, "[Set][Attr:%s]fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str()); + ++data_index; } ge::NodePtr arg_node = graph->AddNode(data_op); @@ -709,6 +710,17 @@ bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) { return true; } +void GeGenerator::RemoveConst(const vector &inputs, vector &outputs) { + for (auto &input : inputs) { + GeTensorDesc input_desc = input.GetTensorDesc(); + bool is_const = false; + (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); + if (!is_const) { + outputs.emplace_back(input); + } + } +} + Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs) { GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); @@ -773,7 +785,9 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in GELOGI("ATC parser success in single op build."); GeRootModelPtr ge_root_model = nullptr; - GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, ge_root_model)); + vector data_inputs; + RemoveConst(inputs, data_inputs); + GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model)); map op_attrs = op_desc_tmp->GetAllAttrs(); GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); @@ -856,18 +870,19 @@ Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vectorGetAllInputsDescPtr()) { GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { continue; } - GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); + GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index)); arg_index++; } } else { for (const auto &in_desc : inputs) { - GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true)); + GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index)); arg_index++; } } diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index 4b8caa95..505c7146 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -99,6 +99,7 @@ class GE_FUNC_VISIBILITY GeGenerator { const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline = true); bool CheckNoAicore(const ComputeGraphPtr &graph); + void RemoveConst(const vector &inputs, vector &outputs); Status CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs); using GeRootModelPtr = std::shared_ptr; diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index 6d0db429..7b087e94 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -128,4 +128,13 @@ TEST_F(UtestGeGenerator, test_set_model_name) { ge_root_model->root_graph_ = std::move(graph); EXPECT_EQ(generator.SetModelNameForDump(ge_root_model), SUCCESS); } + +TEST_F(UtestGeGenerator, test_remove_const) { + GeGenerator generator; + GeTensorDesc tensor_desc; + GeTensor tensor(tensor_desc); + const vector inputs = {tensor}; + vector outputs; + generator.RemoveConst(inputs, outputs); +} } // namespace ge