Merge pull request !626 from 刘江涛/ge_devpull/640/MERGE
| @@ -1 +1 @@ | |||
| Subproject commit 5d062a35640733026457c91966a558769570b0f8 | |||
| Subproject commit f5c1b6d1b6b6e97d0cfcf2efd52ec8da12d32c86 | |||
| @@ -431,6 +431,41 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra | |||
| return SUCCESS; | |||
| } | |||
| domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, | |||
| const std::string &input_data_names) const { | |||
| std::vector<std::string> input_names = StringUtils::Split(input_data_names, ','); | |||
| std::unordered_map<std::string, size_t> name_to_index; | |||
| for (auto &input_name : input_names) { | |||
| if (!name_to_index.emplace(input_name, name_to_index.size()).second) { | |||
| GELOGE(PARAM_INVALID, "[Check][Param] Duplicate input name[%s].", input_name.c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| for (const NodePtr &node : graph->GetDirectNode()) { | |||
| if (node->GetType() != ge::parser::DATA) { | |||
| continue; | |||
| } | |||
| auto op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| auto iter = name_to_index.find(node->GetName()); | |||
| if (iter== name_to_index.cend()) { | |||
| GELOGE(PARAM_INVALID, "[Check][Param] Input name[%s] is not in input_data_names", | |||
| node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| GELOGI("[SetSpecifyIndexAttr] set node(%s) index attr, index is %ld", | |||
| op_desc->GetName().c_str(), iter->second); | |||
| if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, iter->second)) { | |||
| REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s", | |||
| ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); | |||
| GELOGE(FAILED, "set attr %s failed for node:%s", ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name) const { | |||
| output_nodes_name.clear(); | |||
| @@ -670,6 +705,16 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||
| return PARAM_INVALID; | |||
| } | |||
| string input_data_names; | |||
| GetAclParams(parser_params, ge::ir_option::INPUT_DATA_NAMES, input_data_names); | |||
| if (!input_data_names.empty()) { | |||
| if (SetSpecifyIndexAttrByInputNames(compute_graph, input_data_names) != SUCCESS) { | |||
| GELOGE(FAILED, "[Invoke][SetIndexAttr] set index attr failed, graph:%s", | |||
| compute_graph->GetName().c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -61,6 +61,7 @@ class AclGrphParseUtil { | |||
| size_t index, OpDescPtr &op_desc); | |||
| domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | |||
| const string &is_input_adjust_hw_layout) const; | |||
| domi::Status SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, const std::string &input_data_names) const; | |||
| domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const; | |||
| }; | |||
| @@ -1029,7 +1029,9 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { | |||
| ParserOperator unused("Add"); | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/origin_models/tf_add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, | |||
| }; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| @@ -1043,6 +1045,21 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { | |||
| EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | |||
| } | |||
| TEST_F(STestTensorflowParser, tensorflow_parser_failed_for_input_data_names_error) { | |||
| RegisterCustomOp(); | |||
| std::string case_dir = __FILE__; | |||
| ParserOperator unused("Add"); | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/origin_models/tf_add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_3")}, | |||
| }; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, ge::GRAPH_FAILED); | |||
| } | |||
| TEST_F(STestTensorflowParser, tensorflow_model_Failed) { | |||
| ge::Graph graph; | |||
| std::string caseDir = __FILE__; | |||
| @@ -1032,7 +1032,9 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||
| ParserOperator unused("Add"); | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, | |||
| }; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| @@ -1046,6 +1048,21 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||
| EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_parser_input_data_names_failed) { | |||
| RegisterCustomOp(); | |||
| std::string case_dir = __FILE__; | |||
| ParserOperator unused("Add"); | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_2")}, | |||
| }; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, ge::GRAPH_FAILED); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_model_Failed) { | |||
| ge::Graph graph; | |||
| std::string caseDir = __FILE__; | |||