Merge pull request !626 from 刘江涛/ge_devpull/640/MERGE
| @@ -1 +1 @@ | |||||
| Subproject commit 5d062a35640733026457c91966a558769570b0f8 | |||||
| Subproject commit f5c1b6d1b6b6e97d0cfcf2efd52ec8da12d32c86 | |||||
| @@ -431,6 +431,41 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, | |||||
| const std::string &input_data_names) const { | |||||
| std::vector<std::string> input_names = StringUtils::Split(input_data_names, ','); | |||||
| std::unordered_map<std::string, size_t> name_to_index; | |||||
| for (auto &input_name : input_names) { | |||||
| if (!name_to_index.emplace(input_name, name_to_index.size()).second) { | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Duplicate input name[%s].", input_name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| for (const NodePtr &node : graph->GetDirectNode()) { | |||||
| if (node->GetType() != ge::parser::DATA) { | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| auto iter = name_to_index.find(node->GetName()); | |||||
| if (iter== name_to_index.cend()) { | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Input name[%s] is not in input_data_names", | |||||
| node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGI("[SetSpecifyIndexAttr] set node(%s) index attr, index is %ld", | |||||
| op_desc->GetName().c_str(), iter->second); | |||||
| if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, iter->second)) { | |||||
| REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s", | |||||
| ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); | |||||
| GELOGE(FAILED, "set attr %s failed for node:%s", ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | ||||
| std::vector<std::string> &output_nodes_name) const { | std::vector<std::string> &output_nodes_name) const { | ||||
| output_nodes_name.clear(); | output_nodes_name.clear(); | ||||
| @@ -670,6 +705,16 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| string input_data_names; | |||||
| GetAclParams(parser_params, ge::ir_option::INPUT_DATA_NAMES, input_data_names); | |||||
| if (!input_data_names.empty()) { | |||||
| if (SetSpecifyIndexAttrByInputNames(compute_graph, input_data_names) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Invoke][SetIndexAttr] set index attr failed, graph:%s", | |||||
| compute_graph->GetName().c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -61,6 +61,7 @@ class AclGrphParseUtil { | |||||
| size_t index, OpDescPtr &op_desc); | size_t index, OpDescPtr &op_desc); | ||||
| domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | ||||
| const string &is_input_adjust_hw_layout) const; | const string &is_input_adjust_hw_layout) const; | ||||
| domi::Status SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, const std::string &input_data_names) const; | |||||
| domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | ||||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const; | std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const; | ||||
| }; | }; | ||||
| @@ -1029,7 +1029,9 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { | |||||
| ParserOperator unused("Add"); | ParserOperator unused("Add"); | ||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
| std::string model_file = case_dir + "/origin_models/tf_add.pb"; | std::string model_file = case_dir + "/origin_models/tf_add.pb"; | ||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, | |||||
| }; | |||||
| ge::Graph graph; | ge::Graph graph; | ||||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | ||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| @@ -1043,6 +1045,21 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { | |||||
| EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | ||||
| } | } | ||||
| TEST_F(STestTensorflowParser, tensorflow_parser_failed_for_input_data_names_error) { | |||||
| RegisterCustomOp(); | |||||
| std::string case_dir = __FILE__; | |||||
| ParserOperator unused("Add"); | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/origin_models/tf_add.pb"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_3")}, | |||||
| }; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||||
| ASSERT_EQ(ret, ge::GRAPH_FAILED); | |||||
| } | |||||
| TEST_F(STestTensorflowParser, tensorflow_model_Failed) { | TEST_F(STestTensorflowParser, tensorflow_model_Failed) { | ||||
| ge::Graph graph; | ge::Graph graph; | ||||
| std::string caseDir = __FILE__; | std::string caseDir = __FILE__; | ||||
| @@ -1032,7 +1032,9 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||||
| ParserOperator unused("Add"); | ParserOperator unused("Add"); | ||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
| std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; | std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; | ||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, | |||||
| }; | |||||
| ge::Graph graph; | ge::Graph graph; | ||||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | ||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| @@ -1046,6 +1048,21 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||||
| EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | ||||
| } | } | ||||
| TEST_F(UtestTensorflowParser, tensorflow_parser_input_data_names_failed) { | |||||
| RegisterCustomOp(); | |||||
| std::string case_dir = __FILE__; | |||||
| ParserOperator unused("Add"); | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_2")}, | |||||
| }; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||||
| ASSERT_EQ(ret, ge::GRAPH_FAILED); | |||||
| } | |||||
| TEST_F(UtestTensorflowParser, tensorflow_model_Failed) { | TEST_F(UtestTensorflowParser, tensorflow_model_Failed) { | ||||
| ge::Graph graph; | ge::Graph graph; | ||||
| std::string caseDir = __FILE__; | std::string caseDir = __FILE__; | ||||