diff --git a/metadef b/metadef index 5d062a3..f5c1b6d 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 5d062a35640733026457c91966a558769570b0f8 +Subproject commit f5c1b6d1b6b6e97d0cfcf2efd52ec8da12d32c86 diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index f29904e..8a4b261 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -431,6 +431,41 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra return SUCCESS; } +domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, + const std::string &input_data_names) const { + std::vector input_names = StringUtils::Split(input_data_names, ','); + std::unordered_map name_to_index; + for (auto &input_name : input_names) { + if (!name_to_index.emplace(input_name, name_to_index.size()).second) { + GELOGE(PARAM_INVALID, "[Check][Param] Duplicate input name[%s].", input_name.c_str()); + return FAILED; + } + } + + for (const NodePtr &node : graph->GetDirectNode()) { + if (node->GetType() != ge::parser::DATA) { + continue; + } + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto iter = name_to_index.find(node->GetName()); + if (iter== name_to_index.cend()) { + GELOGE(PARAM_INVALID, "[Check][Param] Input name[%s] is not in input_data_names", + node->GetName().c_str()); + return FAILED; + } + GELOGI("[SetSpecifyIndexAttr] set node(%s) index attr, index is %ld", + op_desc->GetName().c_str(), iter->second); + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, iter->second)) { + REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s", + ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); + GELOGE(FAILED, "set attr %s failed for node:%s", ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); + return FAILED; + } + } + return SUCCESS; +} + void AclGrphParseUtil::CreateOutputNodesInfo(std::vector> &output_nodes_info, std::vector &output_nodes_name) const { output_nodes_name.clear(); @@ -670,6 +705,16 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, return PARAM_INVALID; } + string input_data_names; + GetAclParams(parser_params, ge::ir_option::INPUT_DATA_NAMES, input_data_names); + if (!input_data_names.empty()) { + if (SetSpecifyIndexAttrByInputNames(compute_graph, input_data_names) != SUCCESS) { + GELOGE(FAILED, "[Invoke][SetIndexAttr] set index attr failed, graph:%s", + compute_graph->GetName().c_str()); + return PARAM_INVALID; + } + } + return SUCCESS; } diff --git a/parser/common/acl_graph_parser_util.h b/parser/common/acl_graph_parser_util.h index 8af1d27..4ff649f 100644 --- a/parser/common/acl_graph_parser_util.h +++ b/parser/common/acl_graph_parser_util.h @@ -61,6 +61,7 @@ class AclGrphParseUtil { size_t index, OpDescPtr &op_desc); domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, const string &is_input_adjust_hw_layout) const; + domi::Status SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, const std::string &input_data_names) const; domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, std::vector> &output_nodes_info) const; }; diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index 5b41752..06658f9 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -1029,7 +1029,9 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { ParserOperator unused("Add"); case_dir = case_dir.substr(0, case_dir.find_last_of("/")); std::string model_file = case_dir + "/origin_models/tf_add.pb"; - std::map parser_params; + std::map parser_params = { + {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, + }; ge::Graph graph; auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); ASSERT_EQ(ret, SUCCESS); @@ -1043,6 +1045,21 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); } +TEST_F(STestTensorflowParser, tensorflow_parser_failed_for_input_data_names_error) { + RegisterCustomOp(); + + std::string case_dir = __FILE__; + ParserOperator unused("Add"); + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/tf_add.pb"; + std::map parser_params = { + {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_3")}, + }; + ge::Graph graph; + auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); + ASSERT_EQ(ret, ge::GRAPH_FAILED); +} + TEST_F(STestTensorflowParser, tensorflow_model_Failed) { ge::Graph graph; std::string caseDir = __FILE__; diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index 00678d8..e3f6ea4 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -1032,7 +1032,9 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { ParserOperator unused("Add"); case_dir = case_dir.substr(0, case_dir.find_last_of("/")); std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; - std::map parser_params; + std::map parser_params = { + {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, + }; ge::Graph graph; auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); ASSERT_EQ(ret, SUCCESS); @@ -1046,6 +1048,21 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); } +TEST_F(UtestTensorflowParser, tensorflow_parser_input_data_names_failed) { + RegisterCustomOp(); + + std::string case_dir = __FILE__; + ParserOperator unused("Add"); + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; + std::map parser_params = { + {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_2")}, + }; + ge::Graph graph; + auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); + ASSERT_EQ(ret, ge::GRAPH_FAILED); +} + TEST_F(UtestTensorflowParser, tensorflow_model_Failed) { ge::Graph graph; std::string caseDir = __FILE__;