| @@ -1 +1 @@ | |||||
| Subproject commit 0bdbad828640f03195c636f25cc834c381826bb1 | |||||
| Subproject commit 35de9facd31448995922246c5d2ffaa5a726bbb1 | |||||
| @@ -131,6 +131,7 @@ const char *YOLO2REORG = "Yolo2Reorg"; | |||||
| const char *REDUCESUM = "ReduceSum"; | const char *REDUCESUM = "ReduceSum"; | ||||
| const char *SUM = "Sum"; | const char *SUM = "Sum"; | ||||
| const char *CONSTANT = "Const"; | const char *CONSTANT = "Const"; | ||||
| const char *FILECONSTANT = "FileConstant"; | |||||
| const char *RESIZEBILINEAR = "ResizeBilinear"; | const char *RESIZEBILINEAR = "ResizeBilinear"; | ||||
| const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; | const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; | ||||
| const char *MAXIMUM = "Maximum"; | const char *MAXIMUM = "Maximum"; | ||||
| @@ -4,6 +4,7 @@ set(SRC_LIST | |||||
| "onnx_data_parser.cc" | "onnx_data_parser.cc" | ||||
| "onnx_util.cc" | "onnx_util.cc" | ||||
| "onnx_constant_parser.cc" | "onnx_constant_parser.cc" | ||||
| "onnx_file_constant_parser.cc" | |||||
| "subgraph_adapter/if_subgraph_adapter.cc" | "subgraph_adapter/if_subgraph_adapter.cc" | ||||
| "subgraph_adapter/subgraph_adapter_factory.cc" | "subgraph_adapter/subgraph_adapter_factory.cc" | ||||
| ) | ) | ||||
| @@ -17,6 +17,7 @@ PARSER_ONNX_SRC_FILES := \ | |||||
| onnx_data_parser.cc \ | onnx_data_parser.cc \ | ||||
| onnx_util.cc \ | onnx_util.cc \ | ||||
| onnx_constant_parser.cc \ | onnx_constant_parser.cc \ | ||||
| onnx_file_constant_parser.cc \ | |||||
| proto/onnx/ge_onnx.proto \ | proto/onnx/ge_onnx.proto \ | ||||
| proto/om.proto \ | proto/om.proto \ | ||||
| @@ -0,0 +1,150 @@ | |||||
| /** | |||||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "onnx_file_constant_parser.h" | |||||
| #include <vector> | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "parser/onnx/onnx_util.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "framework/common/types.h" | |||||
| using ge::onnx::NodeProto; | |||||
| using ge::onnx::TensorProto; | |||||
| using domi::ONNX; | |||||
| using GeShape = ge::GeShape; | |||||
| using GeTensorDesc = ge::GeTensorDesc; | |||||
| using namespace ge::parser; | |||||
| namespace { | |||||
| const std::string kAttrShape = "shape"; | |||||
| const std::string kAttrDataType = "dtype"; | |||||
| const std::string kFileConstantPath = "file_constant_path"; | |||||
| const std::string kLocation = "location"; | |||||
| const std::string kOffset = "offset"; | |||||
| const int64_t kOffsetCoefficient = 4096; | |||||
| const char *const kFileConstant = "FileConstant"; | |||||
| } | |||||
| namespace ge { | |||||
| Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | |||||
| GE_CHECK_NOTNULL(op_src); | |||||
| const ge::onnx::NodeProto *node = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
| GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| if (GetTensorProto(node, tensor_proto) != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (ParseDataType(tensor_proto, op_def) != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] parse data type failed", node->name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse data type failed", node->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (ParsePath(tensor_proto, op_def) != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] parse file path failed", node->name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse file path failed", node->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ParseShape(tensor_proto, op_def); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto, | |||||
| ge::onnx::TensorProto &tensor_proto) { | |||||
| for (const auto &it : node_proto->attribute()) { | |||||
| if (it.name() != ge::kAttrNameValue) { | |||||
| continue; | |||||
| } | |||||
| tensor_proto = it.t(); | |||||
| return SUCCESS; | |||||
| } | |||||
| REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str()); | |||||
| GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
| std::vector<int64_t> tmp_shape; | |||||
| for (int i = 0; i < tensor_proto.dims_size(); i++) { | |||||
| tmp_shape.push_back(tensor_proto.dims(i)); | |||||
| } | |||||
| op_def.SetAttr(kAttrShape.c_str(), tmp_shape); | |||||
| } | |||||
| Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
| int64_t data_type = tensor_proto.data_type(); | |||||
| ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); | |||||
| if (type >= ge::DataType::DT_UNDEFINED) { | |||||
| REPORT_INNER_ERROR("E19999", "tensor_proto date type %ld is undefined.", data_type); | |||||
| GELOGE(domi::PARAM_INVALID, "[Check][Param] tensor_proto date type %ld is undefined.", data_type); | |||||
| return FAILED; | |||||
| } | |||||
| op_def.SetAttr(kAttrDataType.c_str(), type); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
| ge::NamedAttrs attrs; | |||||
| for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { | |||||
| const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); | |||||
| if (SetPathAttr(string_proto, attrs) != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| if (!attrs.HasAttr(kLocation)) { | |||||
| REPORT_INNER_ERROR("E19999", "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| op_def.SetAttr(kFileConstantPath.c_str(), attrs); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, | |||||
| ge::NamedAttrs &attrs) { | |||||
| if (string_proto.key() == kLocation) { | |||||
| AttrUtils::SetStr(attrs, kLocation, string_proto.value()); | |||||
| } else { | |||||
| int64_t value; | |||||
| try { | |||||
| value = stol(string_proto.value()); | |||||
| } catch (const std::exception &e) { | |||||
| REPORT_INNER_ERROR("E19999", "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
| GELOGE(domi::PARAM_INVALID, "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
| return FAILED; | |||||
| } | |||||
| if (string_proto.key() == kOffset) { | |||||
| if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) { | |||||
| REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||||
| GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||||
| return FAILED; | |||||
| } | |||||
| value *= kOffsetCoefficient; | |||||
| } | |||||
| AttrUtils::SetInt(attrs, string_proto.key(), value); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_OP_PARSER_CREATOR(ONNX, kFileConstant, OnnxFileConstantParser); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ | |||||
| #define GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ | |||||
| #include "parser/onnx/onnx_op_parser.h" | |||||
| #include "proto/onnx/ge_onnx.pb.h" | |||||
| namespace ge { | |||||
| class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser { | |||||
| public: | |||||
| Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | |||||
| private: | |||||
| Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
| Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
| void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
| Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto); | |||||
| Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ | |||||
| @@ -44,6 +44,12 @@ | |||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "subgraph_adapter/subgraph_adapter_factory.h" | #include "subgraph_adapter/subgraph_adapter_factory.h" | ||||
| #include "framework/common/types.h" | |||||
| #include "mmpa/mmpa_api.h" | |||||
| namespace { | |||||
| const std::string kLocation = "location"; | |||||
| } | |||||
| namespace ge { | namespace ge { | ||||
| graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | ||||
| @@ -160,7 +166,8 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const std::map<std::string, std::string> kOnnxOpMap = { | const std::map<std::string, std::string> kOnnxOpMap = { | ||||
| {ge::kOpTypeInput, ge::parser::DATA}, | {ge::kOpTypeInput, ge::parser::DATA}, | ||||
| {ge::kOpTypeConstant, ge::parser::CONSTANT} | |||||
| {ge::kOpTypeConstant, ge::parser::CONSTANT}, | |||||
| {ge::kFileConstant, ge::parser::FILECONSTANT} | |||||
| }; | }; | ||||
| const int64_t kDimValue = 1; | const int64_t kDimValue = 1; | ||||
| @@ -350,12 +357,16 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||||
| ge::onnx::NodeProto *const_node = onnx_graph.add_node(); | ge::onnx::NodeProto *const_node = onnx_graph.add_node(); | ||||
| std::string output_name = it.first + "_" + to_string(index++); | std::string output_name = it.first + "_" + to_string(index++); | ||||
| const_node->set_name(output_name); | const_node->set_name(output_name); | ||||
| const_node->set_op_type(ge::kOpTypeConstant); | |||||
| const_node->add_output(it.first); | const_node->add_output(it.first); | ||||
| ge::onnx::AttributeProto *attribute = const_node->add_attribute(); | ge::onnx::AttributeProto *attribute = const_node->add_attribute(); | ||||
| attribute->set_name(ge::kAttrNameValue); | attribute->set_name(ge::kAttrNameValue); | ||||
| ge::onnx::TensorProto *attribute_t = attribute->mutable_t(); | ge::onnx::TensorProto *attribute_t = attribute->mutable_t(); | ||||
| *attribute_t = it.second; | *attribute_t = it.second; | ||||
| if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
| const_node->set_op_type(kFileConstant); | |||||
| } else { | |||||
| const_node->set_op_type(ge::kOpTypeConstant); | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -723,6 +734,51 @@ Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto | |||||
| GELOGE(PARAM_INVALID, "[Read][ModeFile] failed."); | GELOGE(PARAM_INVALID, "[Read][ModeFile] failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (SetExternalPath(file, onnx_model) != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Set external path failed, file[%s]", file); | |||||
| GELOGE(PARAM_INVALID, "[Set][ExternalPath] failed."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxModelParser::SetExternalPath(const char *file, ge::onnx::ModelProto &onnx_model) const { | |||||
| std::string real_path = ge::parser::RealPath(file); | |||||
| const size_t file_len = real_path.length(); | |||||
| std::unique_ptr<char[]> tmp_file(new (std::nothrow) char[file_len + 1U]); | |||||
| GE_CHECK_NOTNULL(tmp_file); | |||||
| const auto ret = strncpy_s(tmp_file.get(), file_len + 1U, real_path.c_str(), file_len); | |||||
| if (ret != EN_OK) { | |||||
| REPORT_CALL_ERROR("E19999", "strncpy_s failed, src=%p, dst=%p, src_len=%zu, dst_len=%zu, ret=%d.", | |||||
| real_path.c_str(), tmp_file.get(), file_len, file_len + 1U, ret); | |||||
| GELOGE(FAILED, "strncpy_s failed, src=%p, dst=%p, src_len=%zu, dst_len=%zu.", | |||||
| real_path.c_str(), tmp_file.get(), file_len, file_len + 1U); | |||||
| return FAILED; | |||||
| } | |||||
| const char *const dir = mmDirName(tmp_file.get()); | |||||
| GE_CHECK_NOTNULL(dir); | |||||
| const ge::onnx::GraphProto &onnx_graph = onnx_model.graph(); | |||||
| for (int32_t i = 0; i < onnx_graph.initializer_size(); ++i) { | |||||
| const ge::onnx::TensorProto &initializer_tensor = onnx_graph.initializer(i); | |||||
| if (initializer_tensor.data_location() != ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
| continue; | |||||
| } | |||||
| for (int32_t j = 0; j < initializer_tensor.external_data_size(); ++j) { | |||||
| ge::onnx::StringStringEntryProto &string_proto = | |||||
| const_cast<ge::onnx::StringStringEntryProto &>(initializer_tensor.external_data(j)); | |||||
| if (string_proto.key() != kLocation) { | |||||
| continue; | |||||
| } | |||||
| const std::string &file_name = string_proto.value(); | |||||
| const std::string new_file = std::string(dir) + MMPA_PATH_SEPARATOR_STR + file_name; | |||||
| GELOGD("[%s] is external data. concat dir[%s] and file_name[%s], new_file[%s]", | |||||
| initializer_tensor.name().c_str(), dir, file_name.c_str(), new_file.c_str()); | |||||
| string_proto.set_value(new_file); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -126,6 +126,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
| Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) const; | Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) const; | ||||
| Status SetExternalPath(const char *file, ge::onnx::ModelProto &onnx_model) const; | |||||
| Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; | Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; | ||||
| Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph); | Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph); | ||||
| @@ -48,6 +48,7 @@ const char *const kAttrNameIndex = "index"; | |||||
| const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; | const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; | ||||
| const char *const kOpTypeConstant = "Constant"; | const char *const kOpTypeConstant = "Constant"; | ||||
| const char *const kOpTypeInput = "Input"; | const char *const kOpTypeInput = "Input"; | ||||
| const char *const kFileConstant = "FileConstant"; | |||||
| class OnnxUtil { | class OnnxUtil { | ||||
| public: | public: | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
| #include <string> | |||||
| typedef int mmErrorMSg; | typedef int mmErrorMSg; | ||||
| @@ -301,3 +302,22 @@ CHAR *mmGetErrorFormatMessage(mmErrorMSg errnum, CHAR *buf, mmSize size) | |||||
| } | } | ||||
| return strerror_r(errnum, buf, size); | return strerror_r(errnum, buf, size); | ||||
| } | } | ||||
| CHAR *mmDirName(CHAR *path) { | |||||
| if (path == NULL) { | |||||
| return NULL; | |||||
| } | |||||
| #if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) | |||||
| char separator = '\\'; | |||||
| #else | |||||
| char separator = '/'; | |||||
| #endif | |||||
| std::string path_str(path); | |||||
| const size_t last_sep_pos = path_str.rfind(separator); | |||||
| if (last_sep_pos == std::string::npos) { | |||||
| return NULL; | |||||
| } | |||||
| path[last_sep_pos] = '\0'; | |||||
| return path; | |||||
| } | |||||
| @@ -277,6 +277,7 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/common/thread_pool.cc" | "${PARSER_DIR}/parser/common/thread_pool.cc" | ||||
| "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | ||||
| @@ -278,6 +278,7 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/common/thread_pool.cc" | "${PARSER_DIR}/parser/common/thread_pool.cc" | ||||
| "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | ||||
| @@ -30,6 +30,7 @@ | |||||
| #define protected public | #define protected public | ||||
| #define private public | #define private public | ||||
| #include "parser/onnx/onnx_constant_parser.h" | #include "parser/onnx/onnx_constant_parser.h" | ||||
| #include "parser/onnx/onnx_file_constant_parser.h" | |||||
| #include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
| #include "parser/onnx/onnx_parser.h" | #include "parser/onnx/onnx_parser.h" | ||||
| #undef protected | #undef protected | ||||
| @@ -375,6 +376,190 @@ TEST_F(UtestOnnxParser, OnnxConstantParser_ParseConvertDataType_test) | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| } | } | ||||
| TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::NodeProto input_node; | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| Status ret = parser.GetTensorProto(&input_node, tensor_proto); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||||
| attribute->set_name("attribute"); | |||||
| attribute = input_node.add_attribute(); | |||||
| attribute->set_name("value"); | |||||
| ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||||
| *attribute_tensor = tensor_proto; | |||||
| ret = parser.GetTensorProto(&input_node, tensor_proto); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, FileConstantParseShape) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| tensor_proto.add_dims(4); | |||||
| tensor_proto.add_dims(2); | |||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
| parser.ParseShape(tensor_proto, op); | |||||
| std::vector<int64_t> attr_value; | |||||
| op.GetAttr("shape", attr_value); | |||||
| EXPECT_EQ(attr_value.size(), 2U); | |||||
| if (attr_value.size() == 2U) { | |||||
| EXPECT_EQ(attr_value[0], 4); | |||||
| EXPECT_EQ(attr_value[1], 2); | |||||
| } | |||||
| } | |||||
| TEST_F(UtestOnnxParser, FileConstantParseDataType) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| tensor_proto.set_data_type(OnnxDataType::UNDEFINED); | |||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
| Status ret = parser.ParseDataType(tensor_proto, op); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| tensor_proto.set_data_type(OnnxDataType::UINT8); | |||||
| ret = parser.ParseDataType(tensor_proto, op); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| ge::DataType attr_value; | |||||
| op.GetAttr("dtype", attr_value); | |||||
| EXPECT_EQ(attr_value, ge::DataType::DT_UINT8); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, FileConstantParseAttr) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::StringStringEntryProto string_proto; | |||||
| ge::NamedAttrs attrs; | |||||
| // test location | |||||
| string_proto.set_key("location"); | |||||
| string_proto.set_value("/usr/local"); | |||||
| Status ret = parser.SetPathAttr(string_proto, attrs); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| std::string attr_value; | |||||
| AttrUtils::GetStr(attrs, "location", attr_value); | |||||
| EXPECT_EQ(attr_value, "/usr/local"); | |||||
| // test offset | |||||
| string_proto.set_key("offset"); | |||||
| string_proto.set_value("123"); | |||||
| ret = parser.SetPathAttr(string_proto, attrs); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| int64_t offset_value; | |||||
| AttrUtils::GetInt(attrs, "offset", offset_value); | |||||
| EXPECT_EQ(offset_value, 123 * 4096); | |||||
| // offset overflow | |||||
| string_proto.set_key("offset"); | |||||
| string_proto.set_value("9223372036854775800"); | |||||
| ret = parser.SetPathAttr(string_proto, attrs); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| // itol exception | |||||
| string_proto.set_key("offset"); | |||||
| string_proto.set_value("999999999999999999999999999999999999"); | |||||
| ret = parser.SetPathAttr(string_proto, attrs); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, FileConstantParsePath) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
| // without location, error | |||||
| auto ret = parser.ParsePath(tensor_proto, op); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| // SetPathAttr error | |||||
| ge::onnx::StringStringEntryProto *offset_proto = tensor_proto.add_external_data(); | |||||
| offset_proto->set_key("offset"); | |||||
| offset_proto->set_value("999999999999999999999999999999"); | |||||
| ret = parser.ParsePath(tensor_proto, op); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| // has location, success | |||||
| ge::onnx::StringStringEntryProto *string_proto = tensor_proto.add_external_data(); | |||||
| string_proto->set_key("location"); | |||||
| string_proto->set_value("/usr/local"); | |||||
| offset_proto->set_key("offset"); | |||||
| offset_proto->set_value("0"); | |||||
| ret = parser.ParsePath(tensor_proto, op); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| // check location | |||||
| std::string attr_value; | |||||
| ge::NamedAttrs attrs; | |||||
| AttrUtils::GetNamedAttrs(op_desc_src, "file_constant_path", attrs); | |||||
| AttrUtils::GetStr(attrs, "location", attr_value); | |||||
| EXPECT_EQ(attr_value, "/usr/local"); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, FileConstantParseParam) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::NodeProto input_node; | |||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
| // get tensor proto failed | |||||
| auto ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||||
| attribute->set_name("value"); | |||||
| ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||||
| *attribute_tensor = tensor_proto; | |||||
| // parse data type failed | |||||
| attribute_tensor->set_data_type(OnnxDataType::UNDEFINED); | |||||
| ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| // parse path failed | |||||
| attribute_tensor->set_data_type(OnnxDataType::UINT16); | |||||
| ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| // success | |||||
| ge::onnx::StringStringEntryProto *string_proto = attribute_tensor->add_external_data(); | |||||
| string_proto->set_key("location"); | |||||
| string_proto->set_value("/usr/local"); | |||||
| attribute_tensor->add_dims(4); | |||||
| ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| // check location, shape, dtype | |||||
| NamedAttrs attrs; | |||||
| AttrUtils::GetNamedAttrs(*op_desc_src, "file_constant_path", attrs); | |||||
| std::string file_path; | |||||
| AttrUtils::GetStr(attrs, "location", file_path); | |||||
| EXPECT_EQ(file_path, "/usr/local"); | |||||
| std::vector<int64_t> dims; | |||||
| op.GetAttr("shape", dims); | |||||
| EXPECT_EQ(dims.size(), 1); | |||||
| if (!dims.empty()) { | |||||
| EXPECT_EQ(dims[0], 4); | |||||
| } | |||||
| DataType dtype; | |||||
| op.GetAttr("dtype", dtype); | |||||
| EXPECT_EQ(dtype, ge::DataType::DT_UINT16); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | ||||
| { | { | ||||
| OnnxModelParser model_parser; | OnnxModelParser model_parser; | ||||
| @@ -447,6 +632,25 @@ TEST_F(UtestOnnxParser, onnx_test_ModelParseToGraph) | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| } | } | ||||
| TEST_F(UtestOnnxParser, onnx_test_SetExternalPath) | |||||
| { | |||||
| OnnxModelParser modelParser; | |||||
| ge::onnx::ModelProto model_proto; | |||||
| auto ret = modelParser.SetExternalPath("", model_proto); | |||||
| EXPECT_NE(ret, SUCCESS); | |||||
| ge::onnx::GraphProto &graph_proto = const_cast<ge::onnx::GraphProto &>(model_proto.graph()); | |||||
| graph_proto.add_initializer(); | |||||
| ge::onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); | |||||
| tensor_proto->add_external_data(); | |||||
| ge::onnx::StringStringEntryProto *string_proto = tensor_proto->add_external_data(); | |||||
| string_proto->set_key("location"); | |||||
| string_proto->set_value("if.onnx"); | |||||
| ret = modelParser.SetExternalPath("/usr/local", model_proto); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, onnx_test_ParseFromMemory) | TEST_F(UtestOnnxParser, onnx_test_ParseFromMemory) | ||||
| { | { | ||||
| OnnxModelParser modelParser; | OnnxModelParser modelParser; | ||||