Merge pull request !676 from zhangfan/ge_devpull/684/MERGE
| @@ -1 +1 @@ | |||||
| Subproject commit 599fbd9d7f9509b7673af90e186817b5a75ad547 | |||||
| Subproject commit f1af97e1c9ce9164901d4e719d3acaa1b8597d14 | |||||
| @@ -514,15 +514,10 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr | |||||
| if (!default_out_nodes.empty()) { | if (!default_out_nodes.empty()) { | ||||
| for (size_t i = 0; i < default_out_nodes.size(); ++i) { | for (size_t i = 0; i < default_out_nodes.size(); ++i) { | ||||
| ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); | ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); | ||||
| if (out_node == nullptr) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
| {"out_nodes", default_out_nodes[i].first}); | |||||
| GELOGE(domi::FAILED, "[Check][Param] Can not find out_nodes(%zu) (%s) in graph.", | |||||
| i, default_out_nodes[i].first.c_str()); | |||||
| return domi::FAILED; | |||||
| if (out_node != nullptr) { | |||||
| output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second)); | |||||
| GELOGD("Get default output node:%s.", out_node->GetName().c_str()); | |||||
| } | } | ||||
| output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second)); | |||||
| GELOGD("Get default output node:%s.", out_node->GetName().c_str()); | |||||
| } | } | ||||
| return domi::SUCCESS; | return domi::SUCCESS; | ||||
| } | } | ||||
| @@ -672,12 +672,13 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
| } | } | ||||
| Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | ||||
| // subgraph might not have input, or isolated const nodes exist in the graph, | |||||
| // we use constant nodes as the start nodes of graph | |||||
| for (int i = 0; i < onnx_graph.node_size(); i++) { | |||||
| ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||||
| if (node->op_type() == kOpTypeConstant) { | |||||
| input_node_names_.emplace_back(node->name()); | |||||
| if (input_node_names_.empty()) { | |||||
| // subgraph might not have input, we use constant nodes as the start nodes of the graph, | |||||
| for (int i = 0; i < onnx_graph.node_size(); i++) { | |||||
| ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||||
| if (node->op_type() == kOpTypeConstant) { | |||||
| input_node_names_.emplace_back(node->name()); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| for (auto in_name : input_node_names_) { | for (auto in_name : input_node_names_) { | ||||
| @@ -0,0 +1,24 @@ | |||||
| :ß | |||||
| ¡ | |||||
| X"If*K | |||||
| else_branch29 | |||||
| else_out"Constant else_bodyb | |||||
| else_out | |||||
| *K | |||||
| then_branch29 | |||||
| then_out"Constant then_bodyb | |||||
| then_out | |||||
| Y"Constantif_modelZ | |||||
| X | |||||
| b | |||||
| Y | |||||
| B | |||||
| @@ -0,0 +1,45 @@ | |||||
| import os | |||||
| import numpy as np | |||||
| import onnx | |||||
| def gen_onnx(): | |||||
| X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [5]) | |||||
| Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [5]) | |||||
| then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) | |||||
| else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) | |||||
| const_out_node = onnx.helper.make_node("Constant", inputs=[], outputs=["Y"]) | |||||
| then_const_node = onnx.helper.make_node("Constant", inputs=[], outputs=["then_out"]) | |||||
| else_const_node = onnx.helper.make_node("Constant", inputs=[], outputs=["else_out"]) | |||||
| then_body = onnx.helper.make_graph( | |||||
| [then_const_node], | |||||
| "then_body", | |||||
| [], | |||||
| [then_out] | |||||
| ) | |||||
| else_body = onnx.helper.make_graph( | |||||
| [else_const_node], | |||||
| "else_body", | |||||
| [], | |||||
| [else_out] | |||||
| ) | |||||
| if_node = onnx.helper.make_node("If", inputs=["X"], outputs=[], then_branch=then_body, else_branch=else_body) | |||||
| graph_def = onnx.helper.make_graph( | |||||
| [if_node, const_out_node], | |||||
| "if_model", | |||||
| [X], | |||||
| [Y] | |||||
| ) | |||||
| model_def = onnx.helper.make_model(graph_def) | |||||
| model_def.opset_import[0].version=11 | |||||
| onnx.save(model_def, "onnx_if_const_intput.onnx") | |||||
| print(model_def) | |||||
| if __name__ == "__main__": | |||||
| gen_onnx() | |||||
| @@ -174,4 +174,14 @@ TEST_F(STestOnnxParser, onnx_parser_const_data_type) { | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | EXPECT_EQ(ret, GRAPH_SUCCESS); | ||||
| } | } | ||||
| TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) { | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/origin_models/onnx_if_const_intput.onnx"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||