Browse Source

!578 超过2G的onnx模型导入

Merge pull request !578 from 徐剑/copy_branch
pull/582/MERGE
zhangfan Gitee 3 years ago
parent
commit
3277d89f76
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 478 additions and 3 deletions
  1. +1
    -1
      metadef
  2. +1
    -0
      parser/common/parser_types.cc
  3. +1
    -0
      parser/onnx/CMakeLists.txt
  4. +1
    -0
      parser/onnx/module.mk
  5. +150
    -0
      parser/onnx/onnx_file_constant_parser.cc
  6. +37
    -0
      parser/onnx/onnx_file_constant_parser.h
  7. +58
    -2
      parser/onnx/onnx_parser.cc
  8. +2
    -0
      parser/onnx/onnx_parser.h
  9. +1
    -0
      parser/onnx/onnx_util.h
  10. +20
    -0
      tests/depends/mmpa/src/mmpa_stub.cc
  11. +1
    -0
      tests/st/CMakeLists.txt
  12. +1
    -0
      tests/ut/parser/CMakeLists.txt
  13. +204
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 0bdbad828640f03195c636f25cc834c381826bb1
Subproject commit 35de9facd31448995922246c5d2ffaa5a726bbb1

+ 1
- 0
parser/common/parser_types.cc View File

@@ -131,6 +131,7 @@ const char *YOLO2REORG = "Yolo2Reorg";
const char *REDUCESUM = "ReduceSum"; const char *REDUCESUM = "ReduceSum";
const char *SUM = "Sum"; const char *SUM = "Sum";
const char *CONSTANT = "Const"; const char *CONSTANT = "Const";
const char *FILECONSTANT = "FileConstant";
const char *RESIZEBILINEAR = "ResizeBilinear"; const char *RESIZEBILINEAR = "ResizeBilinear";
const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad";
const char *MAXIMUM = "Maximum"; const char *MAXIMUM = "Maximum";


+ 1
- 0
parser/onnx/CMakeLists.txt View File

@@ -4,6 +4,7 @@ set(SRC_LIST
"onnx_data_parser.cc" "onnx_data_parser.cc"
"onnx_util.cc" "onnx_util.cc"
"onnx_constant_parser.cc" "onnx_constant_parser.cc"
"onnx_file_constant_parser.cc"
"subgraph_adapter/if_subgraph_adapter.cc" "subgraph_adapter/if_subgraph_adapter.cc"
"subgraph_adapter/subgraph_adapter_factory.cc" "subgraph_adapter/subgraph_adapter_factory.cc"
) )


+ 1
- 0
parser/onnx/module.mk View File

@@ -17,6 +17,7 @@ PARSER_ONNX_SRC_FILES := \
onnx_data_parser.cc \ onnx_data_parser.cc \
onnx_util.cc \ onnx_util.cc \
onnx_constant_parser.cc \ onnx_constant_parser.cc \
onnx_file_constant_parser.cc \
proto/onnx/ge_onnx.proto \ proto/onnx/ge_onnx.proto \
proto/om.proto \ proto/om.proto \




+ 150
- 0
parser/onnx/onnx_file_constant_parser.cc View File

@@ -0,0 +1,150 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "onnx_file_constant_parser.h"
#include <vector>

#include "graph/ge_tensor.h"
#include "parser/common/op_parser_factory.h"
#include "parser/onnx/onnx_util.h"
#include "framework/common/util.h"
#include "framework/common/types.h"

using ge::onnx::NodeProto;
using ge::onnx::TensorProto;
using domi::ONNX;
using GeShape = ge::GeShape;
using GeTensorDesc = ge::GeTensorDesc;
using namespace ge::parser;

namespace {
const std::string kAttrShape = "shape";
const std::string kAttrDataType = "dtype";
const std::string kFileConstantPath = "file_constant_path";
const std::string kLocation = "location";
const std::string kOffset = "offset";
const int64_t kOffsetCoefficient = 4096;
const char *const kFileConstant = "FileConstant";
}
namespace ge {
Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) {
GE_CHECK_NOTNULL(op_src);
const ge::onnx::NodeProto *node = reinterpret_cast<const ge::onnx::NodeProto *>(op_src);
GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str());

ge::onnx::TensorProto tensor_proto;
if (GetTensorProto(node, tensor_proto) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str());
GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str());
return FAILED;
}
if (ParseDataType(tensor_proto, op_def) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "node[%s] parse data type failed", node->name().c_str());
GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse data type failed", node->name().c_str());
return FAILED;
}
if (ParsePath(tensor_proto, op_def) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "node[%s] parse file path failed", node->name().c_str());
GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse file path failed", node->name().c_str());
return FAILED;
}
ParseShape(tensor_proto, op_def);
return SUCCESS;
}

Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto,
ge::onnx::TensorProto &tensor_proto) {
for (const auto &it : node_proto->attribute()) {
if (it.name() != ge::kAttrNameValue) {
continue;
}
tensor_proto = it.t();
return SUCCESS;
}
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str());
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str());
return FAILED;
}

void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
std::vector<int64_t> tmp_shape;
for (int i = 0; i < tensor_proto.dims_size(); i++) {
tmp_shape.push_back(tensor_proto.dims(i));
}
op_def.SetAttr(kAttrShape.c_str(), tmp_shape);
}

Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
int64_t data_type = tensor_proto.data_type();
ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type);
if (type >= ge::DataType::DT_UNDEFINED) {
REPORT_INNER_ERROR("E19999", "tensor_proto date type %ld is undefined.", data_type);
GELOGE(domi::PARAM_INVALID, "[Check][Param] tensor_proto date type %ld is undefined.", data_type);
return FAILED;
}

op_def.SetAttr(kAttrDataType.c_str(), type);
return SUCCESS;
}

Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
ge::NamedAttrs attrs;
for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) {
const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i);
if (SetPathAttr(string_proto, attrs) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str());
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str());
return FAILED;
}
}

if (!attrs.HasAttr(kLocation)) {
REPORT_INNER_ERROR("E19999", "external tensor proto[%s] must contain location.", tensor_proto.name().c_str());
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str());
return FAILED;
}
op_def.SetAttr(kFileConstantPath.c_str(), attrs);
return SUCCESS;
}

Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto,
ge::NamedAttrs &attrs) {
if (string_proto.key() == kLocation) {
AttrUtils::SetStr(attrs, kLocation, string_proto.value());
} else {
int64_t value;
try {
value = stol(string_proto.value());
} catch (const std::exception &e) {
REPORT_INNER_ERROR("E19999", "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what());
GELOGE(domi::PARAM_INVALID, "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what());
return FAILED;
}
if (string_proto.key() == kOffset) {
if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) {
REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value);
GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value);
return FAILED;
}
value *= kOffsetCoefficient;
}
AttrUtils::SetInt(attrs, string_proto.key(), value);
}
return SUCCESS;
}

REGISTER_OP_PARSER_CREATOR(ONNX, kFileConstant, OnnxFileConstantParser);
} // namespace ge

+ 37
- 0
parser/onnx/onnx_file_constant_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_
#define GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_

#include "parser/onnx/onnx_op_parser.h"
#include "proto/onnx/ge_onnx.pb.h"

namespace ge {
class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser {
public:
Status ParseParams(const Message *op_src, ge::Operator &op_def) override;

private:
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto);
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs);
};
} // namespace ge

#endif // GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_

+ 58
- 2
parser/onnx/onnx_parser.cc View File

@@ -44,6 +44,12 @@
#include "graph/utils/node_utils.h" #include "graph/utils/node_utils.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "subgraph_adapter/subgraph_adapter_factory.h" #include "subgraph_adapter/subgraph_adapter_factory.h"
#include "framework/common/types.h"
#include "mmpa/mmpa_api.h"

namespace {
const std::string kLocation = "location";
}


namespace ge { namespace ge {
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util,
@@ -160,7 +166,8 @@ namespace ge {
namespace { namespace {
const std::map<std::string, std::string> kOnnxOpMap = { const std::map<std::string, std::string> kOnnxOpMap = {
{ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeInput, ge::parser::DATA},
{ge::kOpTypeConstant, ge::parser::CONSTANT}
{ge::kOpTypeConstant, ge::parser::CONSTANT},
{ge::kFileConstant, ge::parser::FILECONSTANT}
}; };
const int64_t kDimValue = 1; const int64_t kDimValue = 1;


@@ -350,12 +357,16 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
ge::onnx::NodeProto *const_node = onnx_graph.add_node(); ge::onnx::NodeProto *const_node = onnx_graph.add_node();
std::string output_name = it.first + "_" + to_string(index++); std::string output_name = it.first + "_" + to_string(index++);
const_node->set_name(output_name); const_node->set_name(output_name);
const_node->set_op_type(ge::kOpTypeConstant);
const_node->add_output(it.first); const_node->add_output(it.first);
ge::onnx::AttributeProto *attribute = const_node->add_attribute(); ge::onnx::AttributeProto *attribute = const_node->add_attribute();
attribute->set_name(ge::kAttrNameValue); attribute->set_name(ge::kAttrNameValue);
ge::onnx::TensorProto *attribute_t = attribute->mutable_t(); ge::onnx::TensorProto *attribute_t = attribute->mutable_t();
*attribute_t = it.second; *attribute_t = it.second;
if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) {
const_node->set_op_type(kFileConstant);
} else {
const_node->set_op_type(ge::kOpTypeConstant);
}
} }


return SUCCESS; return SUCCESS;
@@ -723,6 +734,51 @@ Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto
GELOGE(PARAM_INVALID, "[Read][ModeFile] failed."); GELOGE(PARAM_INVALID, "[Read][ModeFile] failed.");
return FAILED; return FAILED;
} }

if (SetExternalPath(file, onnx_model) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Set external path failed, file[%s]", file);
GELOGE(PARAM_INVALID, "[Set][ExternalPath] failed.");
return PARAM_INVALID;
}
return SUCCESS;
}

Status OnnxModelParser::SetExternalPath(const char *file, ge::onnx::ModelProto &onnx_model) const {
std::string real_path = ge::parser::RealPath(file);
const size_t file_len = real_path.length();
std::unique_ptr<char[]> tmp_file(new (std::nothrow) char[file_len + 1U]);
GE_CHECK_NOTNULL(tmp_file);

const auto ret = strncpy_s(tmp_file.get(), file_len + 1U, real_path.c_str(), file_len);
if (ret != EN_OK) {
REPORT_CALL_ERROR("E19999", "strncpy_s failed, src=%p, dst=%p, src_len=%zu, dst_len=%zu, ret=%d.",
real_path.c_str(), tmp_file.get(), file_len, file_len + 1U, ret);
GELOGE(FAILED, "strncpy_s failed, src=%p, dst=%p, src_len=%zu, dst_len=%zu.",
real_path.c_str(), tmp_file.get(), file_len, file_len + 1U);
return FAILED;
}
const char *const dir = mmDirName(tmp_file.get());
GE_CHECK_NOTNULL(dir);

const ge::onnx::GraphProto &onnx_graph = onnx_model.graph();
for (int32_t i = 0; i < onnx_graph.initializer_size(); ++i) {
const ge::onnx::TensorProto &initializer_tensor = onnx_graph.initializer(i);
if (initializer_tensor.data_location() != ge::onnx::TensorProto_DataLocation_EXTERNAL) {
continue;
}
for (int32_t j = 0; j < initializer_tensor.external_data_size(); ++j) {
ge::onnx::StringStringEntryProto &string_proto =
const_cast<ge::onnx::StringStringEntryProto &>(initializer_tensor.external_data(j));
if (string_proto.key() != kLocation) {
continue;
}
const std::string &file_name = string_proto.value();
const std::string new_file = std::string(dir) + MMPA_PATH_SEPARATOR_STR + file_name;
GELOGD("[%s] is external data. concat dir[%s] and file_name[%s], new_file[%s]",
initializer_tensor.name().c_str(), dir, file_name.c_str(), new_file.c_str());
string_proto.set_value(new_file);
}
}
return SUCCESS; return SUCCESS;
} }




+ 2
- 0
parser/onnx/onnx_parser.h View File

@@ -126,6 +126,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) const; Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) const;


Status SetExternalPath(const char *file, ge::onnx::ModelProto &onnx_model) const;

Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const;


Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph); Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph);


+ 1
- 0
parser/onnx/onnx_util.h View File

@@ -48,6 +48,7 @@ const char *const kAttrNameIndex = "index";
const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; const char *const kAttrNameIsSubgraphOp = "is_subgraph_op";
const char *const kOpTypeConstant = "Constant"; const char *const kOpTypeConstant = "Constant";
const char *const kOpTypeInput = "Input"; const char *const kOpTypeInput = "Input";
const char *const kFileConstant = "FileConstant";


class OnnxUtil { class OnnxUtil {
public: public:


+ 20
- 0
tests/depends/mmpa/src/mmpa_stub.cc View File

@@ -15,6 +15,7 @@
*/ */


#include "mmpa/mmpa_api.h" #include "mmpa/mmpa_api.h"
#include <string>


typedef int mmErrorMSg; typedef int mmErrorMSg;


@@ -301,3 +302,22 @@ CHAR *mmGetErrorFormatMessage(mmErrorMSg errnum, CHAR *buf, mmSize size)
} }
return strerror_r(errnum, buf, size); return strerror_r(errnum, buf, size);
} }

CHAR *mmDirName(CHAR *path) {
if (path == NULL) {
return NULL;
}
#if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER))
char separator = '\\';
#else
char separator = '/';
#endif
std::string path_str(path);
const size_t last_sep_pos = path_str.rfind(separator);
if (last_sep_pos == std::string::npos) {
return NULL;
}

path[last_sep_pos] = '\0';
return path;
}

+ 1
- 0
tests/st/CMakeLists.txt View File

@@ -277,6 +277,7 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/common/thread_pool.cc" "${PARSER_DIR}/parser/common/thread_pool.cc"
"${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc"
"${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_parser.cc"


+ 1
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -278,6 +278,7 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/common/thread_pool.cc" "${PARSER_DIR}/parser/common/thread_pool.cc"
"${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc"
"${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_parser.cc"


+ 204
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -30,6 +30,7 @@
#define protected public #define protected public
#define private public #define private public
#include "parser/onnx/onnx_constant_parser.h" #include "parser/onnx/onnx_constant_parser.h"
#include "parser/onnx/onnx_file_constant_parser.h"
#include "parser/onnx/onnx_util.h" #include "parser/onnx/onnx_util.h"
#include "parser/onnx/onnx_parser.h" #include "parser/onnx/onnx_parser.h"
#undef protected #undef protected
@@ -375,6 +376,190 @@ TEST_F(UtestOnnxParser, OnnxConstantParser_ParseConvertDataType_test)
EXPECT_EQ(ret, FAILED); EXPECT_EQ(ret, FAILED);
} }


TEST_F(UtestOnnxParser, FileConstantGetTensorProto)
{
OnnxFileConstantParser parser;
ge::onnx::NodeProto input_node;
ge::onnx::TensorProto tensor_proto;
Status ret = parser.GetTensorProto(&input_node, tensor_proto);
EXPECT_EQ(ret, FAILED);

ge::onnx::AttributeProto *attribute = input_node.add_attribute();
attribute->set_name("attribute");
attribute = input_node.add_attribute();
attribute->set_name("value");

ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
*attribute_tensor = tensor_proto;
ret = parser.GetTensorProto(&input_node, tensor_proto);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestOnnxParser, FileConstantParseShape)
{
OnnxFileConstantParser parser;
ge::onnx::TensorProto tensor_proto;
tensor_proto.add_dims(4);
tensor_proto.add_dims(2);
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);

parser.ParseShape(tensor_proto, op);

std::vector<int64_t> attr_value;
op.GetAttr("shape", attr_value);
EXPECT_EQ(attr_value.size(), 2U);
if (attr_value.size() == 2U) {
EXPECT_EQ(attr_value[0], 4);
EXPECT_EQ(attr_value[1], 2);
}
}

TEST_F(UtestOnnxParser, FileConstantParseDataType)
{
OnnxFileConstantParser parser;
ge::onnx::TensorProto tensor_proto;
tensor_proto.set_data_type(OnnxDataType::UNDEFINED);
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);

Status ret = parser.ParseDataType(tensor_proto, op);
EXPECT_EQ(ret, FAILED);

tensor_proto.set_data_type(OnnxDataType::UINT8);
ret = parser.ParseDataType(tensor_proto, op);
EXPECT_EQ(ret, SUCCESS);
ge::DataType attr_value;
op.GetAttr("dtype", attr_value);
EXPECT_EQ(attr_value, ge::DataType::DT_UINT8);
}

TEST_F(UtestOnnxParser, FileConstantParseAttr)
{
OnnxFileConstantParser parser;
ge::onnx::StringStringEntryProto string_proto;
ge::NamedAttrs attrs;

// test location
string_proto.set_key("location");
string_proto.set_value("/usr/local");
Status ret = parser.SetPathAttr(string_proto, attrs);
EXPECT_EQ(ret, SUCCESS);
std::string attr_value;
AttrUtils::GetStr(attrs, "location", attr_value);
EXPECT_EQ(attr_value, "/usr/local");

// test offset
string_proto.set_key("offset");
string_proto.set_value("123");
ret = parser.SetPathAttr(string_proto, attrs);
EXPECT_EQ(ret, SUCCESS);
int64_t offset_value;
AttrUtils::GetInt(attrs, "offset", offset_value);
EXPECT_EQ(offset_value, 123 * 4096);

// offset overflow
string_proto.set_key("offset");
string_proto.set_value("9223372036854775800");
ret = parser.SetPathAttr(string_proto, attrs);
EXPECT_EQ(ret, FAILED);

// itol exception
string_proto.set_key("offset");
string_proto.set_value("999999999999999999999999999999999999");
ret = parser.SetPathAttr(string_proto, attrs);
EXPECT_EQ(ret, FAILED);
}

TEST_F(UtestOnnxParser, FileConstantParsePath)
{
OnnxFileConstantParser parser;
ge::onnx::TensorProto tensor_proto;
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);


// without location, error
auto ret = parser.ParsePath(tensor_proto, op);
EXPECT_EQ(ret, FAILED);

// SetPathAttr error
ge::onnx::StringStringEntryProto *offset_proto = tensor_proto.add_external_data();
offset_proto->set_key("offset");
offset_proto->set_value("999999999999999999999999999999");
ret = parser.ParsePath(tensor_proto, op);
EXPECT_EQ(ret, FAILED);

// has location, success
ge::onnx::StringStringEntryProto *string_proto = tensor_proto.add_external_data();
string_proto->set_key("location");
string_proto->set_value("/usr/local");
offset_proto->set_key("offset");
offset_proto->set_value("0");
ret = parser.ParsePath(tensor_proto, op);
EXPECT_EQ(ret, SUCCESS);

// check location
std::string attr_value;
ge::NamedAttrs attrs;
AttrUtils::GetNamedAttrs(op_desc_src, "file_constant_path", attrs);
AttrUtils::GetStr(attrs, "location", attr_value);
EXPECT_EQ(attr_value, "/usr/local");
}

TEST_F(UtestOnnxParser, FileConstantParseParam)
{
OnnxFileConstantParser parser;
ge::onnx::NodeProto input_node;
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);

// get tensor proto failed
auto ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op);
EXPECT_EQ(ret, FAILED);

ge::onnx::TensorProto tensor_proto;
ge::onnx::AttributeProto *attribute = input_node.add_attribute();
attribute->set_name("value");
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
*attribute_tensor = tensor_proto;

// parse data type failed
attribute_tensor->set_data_type(OnnxDataType::UNDEFINED);
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op);
EXPECT_EQ(ret, FAILED);

// parse path failed
attribute_tensor->set_data_type(OnnxDataType::UINT16);
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op);
EXPECT_EQ(ret, FAILED);

// success
ge::onnx::StringStringEntryProto *string_proto = attribute_tensor->add_external_data();
string_proto->set_key("location");
string_proto->set_value("/usr/local");
attribute_tensor->add_dims(4);
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op);
EXPECT_EQ(ret, SUCCESS);

// check location, shape, dtype
NamedAttrs attrs;
AttrUtils::GetNamedAttrs(*op_desc_src, "file_constant_path", attrs);
std::string file_path;
AttrUtils::GetStr(attrs, "location", file_path);
EXPECT_EQ(file_path, "/usr/local");

std::vector<int64_t> dims;
op.GetAttr("shape", dims);
EXPECT_EQ(dims.size(), 1);
if (!dims.empty()) {
EXPECT_EQ(dims[0], 4);
}
DataType dtype;
op.GetAttr("dtype", dtype);
EXPECT_EQ(dtype, ge::DataType::DT_UINT16);
}

TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test)
{ {
OnnxModelParser model_parser; OnnxModelParser model_parser;
@@ -447,6 +632,25 @@ TEST_F(UtestOnnxParser, onnx_test_ModelParseToGraph)
EXPECT_EQ(ret, FAILED); EXPECT_EQ(ret, FAILED);
} }


TEST_F(UtestOnnxParser, onnx_test_SetExternalPath)
{
OnnxModelParser modelParser;
ge::onnx::ModelProto model_proto;
auto ret = modelParser.SetExternalPath("", model_proto);
EXPECT_NE(ret, SUCCESS);

ge::onnx::GraphProto &graph_proto = const_cast<ge::onnx::GraphProto &>(model_proto.graph());
graph_proto.add_initializer();
ge::onnx::TensorProto* tensor_proto = graph_proto.add_initializer();
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
tensor_proto->add_external_data();
ge::onnx::StringStringEntryProto *string_proto = tensor_proto->add_external_data();
string_proto->set_key("location");
string_proto->set_value("if.onnx");
ret = modelParser.SetExternalPath("/usr/local", model_proto);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestOnnxParser, onnx_test_ParseFromMemory) TEST_F(UtestOnnxParser, onnx_test_ParseFromMemory)
{ {
OnnxModelParser modelParser; OnnxModelParser modelParser;


Loading…
Cancel
Save