From 415559095538a2959307b2634d749368318a2199 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 20 Mar 2021 18:00:29 +0800 Subject: [PATCH 1/4] Fix bug of const input index. --- ge/generator/ge_generator.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index c138ec0d..e2426682 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -850,6 +850,7 @@ Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector(graph_name); GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); + // 1. Add Node to ComputeGraph. NodePtr op_node = compute_graph->AddNode(op_desc); GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); From e1eb148756b27dc87d836a60e73afdcc0d098c56 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 20 Mar 2021 18:03:08 +0800 Subject: [PATCH 2/4] Fix bug of const input index. --- ge/generator/ge_generator.cc | 28 +++++++++++++++++++------- inc/framework/generator/ge_generator.h | 1 + 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index e2426682..2ff0c327 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -154,7 +154,7 @@ static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_ty } static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index, - bool attr) { + bool attr, int32_t &data_index) { GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID); @@ -197,9 +197,10 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const "[Add][InputDesc]fail for node:%s", data_op->GetName().c_str()); GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "[Add][OutputDesc]fail for node:%s", data_op->GetName().c_str()); - if (attr) { - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, index), return FAILED, + if (attr && !is_const) { + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index), return FAILED, "[Set][Attr:%s]fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str()); + ++data_index; } ge::NodePtr arg_node = graph->AddNode(data_op); @@ -709,6 +710,17 @@ bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) { return true; } +void GeGenerator::RemoveConst(const vector &inputs, vector &outputs) { + for (auto input : inputs) { + GeTensorDesc input_desc = input.GetTensorDesc(); + bool is_const = false; + (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const); + if (!is_const) { + outputs.emplace_back(input); + } + } +} + Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs) { GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); @@ -773,7 +785,9 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in GELOGI("ATC parser success in single op build."); GeRootModelPtr ge_root_model = nullptr; - GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, ge_root_model)); + vector data_inputs; + RemoveConst(inputs, data_inputs); + GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model)); map op_attrs = op_desc_tmp->GetAllAttrs(); GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); @@ -850,25 +864,25 @@ Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector(graph_name); GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); - // 1. Add Node to ComputeGraph. NodePtr op_node = compute_graph->AddNode(op_desc); GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); // 2. Create InputData node. int32_t arg_index = 0; + int32_t data_index = 0; if (inputs.empty()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { continue; } - GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); + GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index)); arg_index++; } } else { for (const auto &in_desc : inputs) { - GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true)); + GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index)); arg_index++; } } diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index 4b8caa95..505c7146 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -99,6 +99,7 @@ class GE_FUNC_VISIBILITY GeGenerator { const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline = true); bool CheckNoAicore(const ComputeGraphPtr &graph); + void RemoveConst(const vector &inputs, vector &outputs); Status CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs); using GeRootModelPtr = std::shared_ptr; From 6af2a247fdb46899e5783a5eb4f3fa69bff15f9f Mon Sep 17 00:00:00 2001 From: zhaozhixuan Date: Tue, 23 Mar 2021 13:51:41 +0800 Subject: [PATCH 3/4] Add ut. --- ge/generator/ge_generator.cc | 2 +- tests/ut/ge/generator/ge_generator_unittest.cc | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 2ff0c327..65ae5501 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -714,7 +714,7 @@ void GeGenerator::RemoveConst(const vector &inputs, vector & for (auto input : inputs) { GeTensorDesc input_desc = input.GetTensorDesc(); bool is_const = false; - (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const); + (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); if (!is_const) { outputs.emplace_back(input); } diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index 6d0db429..7b087e94 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -128,4 +128,13 @@ TEST_F(UtestGeGenerator, test_set_model_name) { ge_root_model->root_graph_ = std::move(graph); EXPECT_EQ(generator.SetModelNameForDump(ge_root_model), SUCCESS); } + +TEST_F(UtestGeGenerator, test_remove_const) { + GeGenerator generator; + GeTensorDesc tensor_desc; + GeTensor tensor(tensor_desc); + const vector inputs = {tensor}; + vector outputs; + generator.RemoveConst(inputs, outputs); +} } // namespace ge From 871efe285e579ff70e7cf267b27521b352f328e3 Mon Sep 17 00:00:00 2001 From: zhaozhixuan Date: Tue, 23 Mar 2021 14:41:48 +0800 Subject: [PATCH 4/4] Add ut. --- ge/generator/ge_generator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 65ae5501..4cd5d34f 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -711,7 +711,7 @@ bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) { } void GeGenerator::RemoveConst(const vector &inputs, vector &outputs) { - for (auto input : inputs) { + for (auto &input : inputs) { GeTensorDesc input_desc = input.GetTensorDesc(); bool is_const = false; (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);