| @@ -117,6 +117,7 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| op_def.SetAttr(kFileConstantPath, attrs); | op_def.SetAttr(kFileConstantPath, attrs); | ||||
| GELOGD("The weight file of Op[%s] is: [%s].", tensor_proto.name().c_str(), attrs.GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -366,6 +366,7 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||||
| *attribute_t = it.second; | *attribute_t = it.second; | ||||
| if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | ||||
| const_node->set_op_type(kFileConstant); | const_node->set_op_type(kFileConstant); | ||||
| GELOGD("Initializer const node [%s], the weight was stored in the file.", const_node->name().c_str()); | |||||
| } else { | } else { | ||||
| const_node->set_op_type(ge::kOpTypeConstant); | const_node->set_op_type(ge::kOpTypeConstant); | ||||
| } | } | ||||
| @@ -374,7 +375,21 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const { | |||||
| void OnnxModelParser::UpdateConstantOpType(ge::onnx::NodeProto *node) const { | |||||
| // If weight in file, Marker Constant(not Initializer) as file constant | |||||
| for (auto it : node->attribute()) { | |||||
| if (it.name() == ge::kAttrNameValue) { | |||||
| const ::ge::onnx::TensorProto tensor_proto = it.t(); | |||||
| if (tensor_proto.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
| node->set_op_type(kFileConstant); | |||||
| GELOGD("Const node [%s], the weight was stored in the file.", node->name().c_str()); | |||||
| } | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| void OnnxModelParser::UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const { | |||||
| int index = 0; | int index = 0; | ||||
| for (int i = 0; i < onnx_graph.node_size(); i++) { | for (int i = 0; i < onnx_graph.node_size(); i++) { | ||||
| ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | ||||
| @@ -382,6 +397,9 @@ void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const | |||||
| std::string node_name = node->op_type() + "_" + to_string(index++); | std::string node_name = node->op_type() + "_" + to_string(index++); | ||||
| node->set_name(node_name); | node->set_name(node_name); | ||||
| } | } | ||||
| if (node->op_type() == kOpTypeConstant) { | |||||
| UpdateConstantOpType(node); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -966,7 +984,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||||
| } | } | ||||
| GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | ||||
| // 3. Parse Constant from graph. | |||||
| // 3. Parse Constant(initializer) from graph. | |||||
| ret = ParseInitializer(onnx_graph, initializer_name_tensor); | ret = ParseInitializer(onnx_graph, initializer_name_tensor); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[Parse][Initializer] for onnx failed."); | GELOGE(ret, "[Parse][Initializer] for onnx failed."); | ||||
| @@ -980,8 +998,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| // 5. Update node name for node do not has name. | |||||
| UpdateAllNodeName(onnx_graph); | |||||
| // 5. Update node name for node do not has name, update const op type | |||||
| UpdateNodeNameAndOpType(onnx_graph); | |||||
| // 6 Precheck. | // 6 Precheck. | ||||
| ret = Prechecker(onnx_graph); | ret = Prechecker(onnx_graph); | ||||
| @@ -105,7 +105,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
| Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | ||||
| std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) const; | std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) const; | ||||
| void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const; | |||||
| void UpdateConstantOpType(ge::onnx::NodeProto *node) const; | |||||
| void UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const; | |||||
| Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | ||||
| @@ -25,7 +25,10 @@ | |||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "tests/depends/ops_stub/ops_stub.h" | #include "tests/depends/ops_stub/ops_stub.h" | ||||
| #include "framework/omg/parser/parser_factory.h" | #include "framework/omg/parser/parser_factory.h" | ||||
| #include "parser/onnx/onnx_util.h" | |||||
| #define private public | |||||
| #include "parser/onnx/onnx_parser.h" | #include "parser/onnx/onnx_parser.h" | ||||
| #undef private | |||||
| namespace ge { | namespace ge { | ||||
| class STestOnnxParser : public testing::Test { | class STestOnnxParser : public testing::Test { | ||||
| @@ -103,6 +106,31 @@ void STestOnnxParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| ge::onnx::GraphProto CreateOnnxGraph() { | |||||
| ge::onnx::GraphProto onnx_graph; | |||||
| (void)onnx_graph.add_input(); | |||||
| (void)onnx_graph.add_output(); | |||||
| ::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); | |||||
| ::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); | |||||
| ::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); | |||||
| node_const1->set_op_type(kOpTypeConstant); | |||||
| node_const2->set_op_type(kOpTypeConstant); | |||||
| node_add->set_op_type("Add"); | |||||
| ::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | |||||
| attr->set_name(ge::kAttrNameValue); | |||||
| ::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||||
| attr = node_const1->add_attribute(); | |||||
| attr = node_const2->add_attribute(); | |||||
| attr->set_name(ge::kAttrNameValue); | |||||
| tensor_proto = attr->mutable_t(); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | |||||
| return onnx_graph; | |||||
| } | |||||
| TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { | TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { | ||||
| std::string case_dir = __FILE__; | std::string case_dir = __FILE__; | ||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
| @@ -184,4 +212,15 @@ TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) { | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | EXPECT_EQ(ret, GRAPH_SUCCESS); | ||||
| } | } | ||||
| TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph) | |||||
| { | |||||
| OnnxModelParser modelParser; | |||||
| ge::onnx::ModelProto model_proto; | |||||
| auto onnx_graph = model_proto.mutable_graph(); | |||||
| *onnx_graph = CreateOnnxGraph(); | |||||
| ge::Graph graph; | |||||
| Status ret = modelParser.ModelParseToGraph(model_proto, graph); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -111,6 +111,29 @@ void UtestOnnxParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| ge::onnx::GraphProto CreateOnnxGraph() { | |||||
| ge::onnx::GraphProto onnx_graph; | |||||
| ::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); | |||||
| ::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); | |||||
| ::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); | |||||
| node_const1->set_op_type(kOpTypeConstant); | |||||
| node_const2->set_op_type(kOpTypeConstant); | |||||
| node_add->set_op_type("Add"); | |||||
| ::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | |||||
| attr->set_name(ge::kAttrNameValue); | |||||
| ::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||||
| attr = node_const1->add_attribute(); | |||||
| attr = node_const2->add_attribute(); | |||||
| attr->set_name(ge::kAttrNameValue); | |||||
| tensor_proto = attr->mutable_t(); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | |||||
| return onnx_graph; | |||||
| } | |||||
| TEST_F(UtestOnnxParser, onnx_parser_if_node) { | TEST_F(UtestOnnxParser, onnx_parser_if_node) { | ||||
| std::string case_dir = __FILE__; | std::string case_dir = __FILE__; | ||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
| @@ -575,6 +598,16 @@ TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | |||||
| EXPECT_EQ(ret, domi::FAILED); | EXPECT_EQ(ret, domi::FAILED); | ||||
| } | } | ||||
| TEST_F(UtestOnnxParser, OnnxModelParser_ParseConstant_test) | |||||
| { | |||||
| OnnxModelParser model_parser; | |||||
| ge::onnx::GraphProto onnx_graph = CreateOnnxGraph(); | |||||
| model_parser.UpdateNodeNameAndOpType(onnx_graph); | |||||
| std::string type = onnx_graph.mutable_node(0)->op_type(); | |||||
| EXPECT_EQ(type, kFileConstant); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) | TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) | ||||
| { | { | ||||
| ge::onnx::ModelProto model_proto; | ge::onnx::ModelProto model_proto; | ||||