Browse Source

!626 add input data names option

Merge pull request !626 from 刘江涛/ge_dev
pull/640/MERGE
刘江涛 i-robot 3 years ago
parent
commit
4ecf89acae
5 changed files with 83 additions and 3 deletions
  1. +1
    -1
      metadef
  2. +45
    -0
      parser/common/acl_graph_parser_util.cc
  3. +1
    -0
      parser/common/acl_graph_parser_util.h
  4. +18
    -1
      tests/st/testcase/test_tensorflow_parser.cc
  5. +18
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 5d062a35640733026457c91966a558769570b0f8
Subproject commit f5c1b6d1b6b6e97d0cfcf2efd52ec8da12d32c86

+ 45
- 0
parser/common/acl_graph_parser_util.cc View File

@@ -431,6 +431,41 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra
return SUCCESS;
}

domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph,
const std::string &input_data_names) const {
std::vector<std::string> input_names = StringUtils::Split(input_data_names, ',');
std::unordered_map<std::string, size_t> name_to_index;
for (auto &input_name : input_names) {
if (!name_to_index.emplace(input_name, name_to_index.size()).second) {
GELOGE(PARAM_INVALID, "[Check][Param] Duplicate input name[%s].", input_name.c_str());
return FAILED;
}
}

for (const NodePtr &node : graph->GetDirectNode()) {
if (node->GetType() != ge::parser::DATA) {
continue;
}
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
auto iter = name_to_index.find(node->GetName());
if (iter== name_to_index.cend()) {
GELOGE(PARAM_INVALID, "[Check][Param] Input name[%s] is not in input_data_names",
node->GetName().c_str());
return FAILED;
}
GELOGI("[SetSpecifyIndexAttr] set node(%s) index attr, index is %ld",
op_desc->GetName().c_str(), iter->second);
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, iter->second)) {
REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s",
ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str());
GELOGE(FAILED, "set attr %s failed for node:%s", ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str());
return FAILED;
}
}
return SUCCESS;
}

void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) const {
output_nodes_name.clear();
@@ -670,6 +705,16 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
return PARAM_INVALID;
}

string input_data_names;
GetAclParams(parser_params, ge::ir_option::INPUT_DATA_NAMES, input_data_names);
if (!input_data_names.empty()) {
if (SetSpecifyIndexAttrByInputNames(compute_graph, input_data_names) != SUCCESS) {
GELOGE(FAILED, "[Invoke][SetIndexAttr] set index attr failed, graph:%s",
compute_graph->GetName().c_str());
return PARAM_INVALID;
}
}

return SUCCESS;
}



+ 1
- 0
parser/common/acl_graph_parser_util.h View File

@@ -61,6 +61,7 @@ class AclGrphParseUtil {
size_t index, OpDescPtr &op_desc);
domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
const string &is_input_adjust_hw_layout) const;
domi::Status SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, const std::string &input_data_names) const;
domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const;
};


+ 18
- 1
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -1029,7 +1029,9 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) {
ParserOperator unused("Add");
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params;
std::map<ge::AscendString, ge::AscendString> parser_params = {
{ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")},
};
ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, SUCCESS);
@@ -1043,6 +1045,21 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) {
EXPECT_EQ(net_out_name.at(0), "add_test_1:0");
}

TEST_F(STestTensorflowParser, tensorflow_parser_failed_for_input_data_names_error) {
RegisterCustomOp();

std::string case_dir = __FILE__;
ParserOperator unused("Add");
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params = {
{ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_3")},
};
ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, ge::GRAPH_FAILED);
}

TEST_F(STestTensorflowParser, tensorflow_model_Failed) {
ge::Graph graph;
std::string caseDir = __FILE__;


+ 18
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -1032,7 +1032,9 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) {
ParserOperator unused("Add");
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/tensorflow_model/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params;
std::map<ge::AscendString, ge::AscendString> parser_params = {
{ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")},
};
ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, SUCCESS);
@@ -1046,6 +1048,21 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) {
EXPECT_EQ(net_out_name.at(0), "add_test_1:0");
}

TEST_F(UtestTensorflowParser, tensorflow_parser_input_data_names_failed) {
RegisterCustomOp();

std::string case_dir = __FILE__;
ParserOperator unused("Add");
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/tensorflow_model/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params = {
{ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_2")},
};
ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, ge::GRAPH_FAILED);
}

TEST_F(UtestTensorflowParser, tensorflow_model_Failed) {
ge::Graph graph;
std::string caseDir = __FILE__;


Loading…
Cancel
Save