| @@ -29,17 +29,6 @@ const std::map<uint32_t, ge::DataType> onnx_data_type_map = { | |||
| {OnnxDataType::COMPLEX64, ge::DataType::DT_COMPLEX64}, {OnnxDataType::COMPLEX128, ge::DataType::DT_COMPLEX128}, | |||
| {OnnxDataType::BFLOAT16, ge::DataType::DT_UNDEFINED}, | |||
| }; | |||
| const std::map<uint32_t, int64_t> onnx_data_type_size_map = { | |||
| {OnnxDataType::FLOAT, sizeof(float)}, {OnnxDataType::UINT8, sizeof(uint8_t)}, | |||
| {OnnxDataType::INT8, sizeof(int8_t)}, {OnnxDataType::UINT16, sizeof(uint16_t)}, | |||
| {OnnxDataType::INT16, sizeof(int16_t)}, {OnnxDataType::INT32, sizeof(int32_t)}, | |||
| {OnnxDataType::INT64, sizeof(int64_t)}, {OnnxDataType::STRING, sizeof(std::string)}, | |||
| {OnnxDataType::BOOL, sizeof(bool)}, {OnnxDataType::FLOAT16, 2}, | |||
| {OnnxDataType::DOUBLE, sizeof(double)}, {OnnxDataType::UINT32, sizeof(uint32_t)}, | |||
| {OnnxDataType::UINT64, sizeof(uint64_t)}, {OnnxDataType::COMPLEX64, 8}, | |||
| {OnnxDataType::COMPLEX128, 16}, {OnnxDataType::BFLOAT16, 2}, | |||
| }; | |||
| } | |||
| namespace ge { | |||
| @@ -52,15 +41,6 @@ ge::DataType OnnxUtil::ConvertOnnxDataType(int64_t onnx_data_type) { | |||
| } | |||
| } | |||
| int64_t OnnxUtil::CaculateDataSize(int64_t onnx_data_type) { | |||
| auto search = onnx_data_type_size_map.find(onnx_data_type); | |||
| if (search != onnx_data_type_size_map.end()) { | |||
| return search->second; | |||
| } else { | |||
| return ge::DataType::DT_UNDEFINED; | |||
| } | |||
| } | |||
| void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | |||
| const std::string &parent_node_name, std::string &unique_subgraph_name) { | |||
| unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; | |||
| @@ -52,7 +52,6 @@ const char *const kOpTypeInput = "Input"; | |||
| class OnnxUtil { | |||
| public: | |||
| static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); | |||
| static int64_t CaculateDataSize(int64_t onnx_data_type); | |||
| static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | |||
| const std::string &parent_node_name, std::string &unique_subgraph_name); | |||
| }; | |||
| @@ -19,6 +19,7 @@ | |||
| #include "external/graph/operator_reg.h" | |||
| #include "register/op_registry.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| namespace ge { | |||
| // for ir | |||
| @@ -99,6 +100,14 @@ REG_OP(Abs) | |||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||
| .OP_END_FACTORY_REG(Abs) | |||
| REG_OP(PartitionedCall) | |||
| .DYNAMIC_INPUT(args, TensorType::ALL()) | |||
| .DYNAMIC_OUTPUT(output, TensorType::ALL()) | |||
| .GRAPH(f) | |||
| .ATTR(config, String, "") | |||
| .ATTR(config_proto, String, "") | |||
| .ATTR(executor_type, String, "") | |||
| .OP_END_FACTORY_REG(PartitionedCall) | |||
| // for plugin | |||
| static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||
| @@ -127,6 +136,29 @@ static Status ParseSubgraphPostFnIfStub(const std::string& subgraph_name, const | |||
| }); | |||
| } | |||
| static Status ParseParamsClipV9Stub(const Message* op_src, ge::Operator& op_dest) { | |||
| auto opDesc = ge::OpDescUtils::GetOpDescFromOperator(op_dest); | |||
| // 1.add dynamic input and out | |||
| opDesc->AddDynamicInputDesc("x", 1); | |||
| opDesc->AddDynamicOutputDesc("output", 1); | |||
| // 2.set original_type | |||
| ge::AttrUtils::SetStr(opDesc, "original_type", "ai.onnx::9::Clip"); | |||
| return SUCCESS; | |||
| } | |||
| static Status ParseOpToGraphClipV9Stub(const Operator& op, Graph& graph) { | |||
| auto data0 = op::Data("data0").set_attr_index(0); | |||
| auto abs0 = op::Abs("abs0").set_input_x(data0); | |||
| std::vector<Operator> inputs{data0}; | |||
| std::vector<std::pair<Operator, std::vector<size_t> > > output_indexs; | |||
| output_indexs.emplace_back(abs0, vector<std::size_t>{0}); | |||
| graph.SetInputs(inputs).SetOutputs(output_indexs); | |||
| return SUCCESS; | |||
| } | |||
| // caffe plugin | |||
| REGISTER_CUSTOM_OP("Data") | |||
| .FrameworkType(domi::CAFFE) | |||
| @@ -170,5 +202,12 @@ REGISTER_CUSTOM_OP("Add") | |||
| .FrameworkType(domi::TENSORFLOW) | |||
| .OriginOpType("Add") | |||
| .ParseParamsFn(ParseParamsStub); | |||
| REGISTER_CUSTOM_OP("PartitionedCall") | |||
| .FrameworkType(domi::ONNX) | |||
| .OriginOpType({"ai.onnx::9::Clip"}) | |||
| .ParseParamsFn(ParseParamsClipV9Stub) | |||
| .ParseOpToGraphFn(ParseOpToGraphClipV9Stub); | |||
| } // namespace ge | |||
| #endif // MAIN_OPS_STUB_H | |||
| @@ -16,6 +16,7 @@ | |||
| #include "st/parser_st_utils.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include <limits.h> | |||
| namespace ge { | |||
| void ParerSTestsUtils::ClearParserInnerCtx() { | |||
| @@ -41,4 +42,67 @@ void ParerSTestsUtils::ClearParserInnerCtx() { | |||
| ge::GetParserContext().enable_scope_fusion_passes = ""; | |||
| GELOGI("Clear parser inner context successfully."); | |||
| } | |||
| MemBuffer* ParerSTestsUtils::MemBufferFromFile(const char *path) { | |||
| char path_temp[PATH_MAX + 1] = {0x00}; | |||
| if(strlen(path) > PATH_MAX || nullptr == realpath(path, path_temp)) { | |||
| return nullptr; | |||
| } | |||
| FILE *fp = fopen(path_temp, "r+"); | |||
| if (fp == nullptr) { | |||
| return nullptr; | |||
| } | |||
| // get model file length | |||
| if (0 != fseek(fp, 0, SEEK_END)) { | |||
| fclose(fp); | |||
| return nullptr; | |||
| } | |||
| long file_length = ftell(fp); | |||
| if (fseek(fp, 0, SEEK_SET)) { | |||
| fclose(fp); | |||
| return nullptr; | |||
| } | |||
| if (file_length <= 0) { | |||
| fclose(fp); | |||
| return nullptr; | |||
| } | |||
| // alloc model buffer | |||
| void *data = malloc((unsigned int)file_length); | |||
| if (!data) { | |||
| fclose(fp); | |||
| return nullptr; | |||
| } | |||
| // read file into memory | |||
| uint32_t read_size = (uint32_t)fread(data, 1, (unsigned int)file_length, fp); | |||
| // check if read success | |||
| if ((long)read_size != file_length) { | |||
| free(data); | |||
| data = nullptr; | |||
| fclose(fp); | |||
| return nullptr; | |||
| } | |||
| // close model file | |||
| fclose(fp); | |||
| // create an MemBuffer | |||
| MemBuffer* membuf = new MemBuffer(); | |||
| if (!membuf) { | |||
| free(data); | |||
| data = nullptr; | |||
| return nullptr; | |||
| } | |||
| membuf->data = malloc((unsigned int)read_size); | |||
| // set size && data | |||
| membuf->size = (uint32_t)read_size; | |||
| memcpy((char*)membuf->data, (char*)data, read_size); | |||
| free(data); | |||
| return membuf; | |||
| } | |||
| } // namespace ge | |||
| @@ -20,9 +20,15 @@ | |||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||
| namespace ge { | |||
| struct MemBuffer { | |||
| void *data; | |||
| uint32_t size; | |||
| }; | |||
| class ParerSTestsUtils { | |||
| public: | |||
| static void ClearParserInnerCtx(); | |||
| static MemBuffer* MemBufferFromFile(const char *path); | |||
| }; | |||
| } // namespace ge | |||
| @@ -0,0 +1,28 @@ | |||
| import onnx | |||
| from onnx import helper | |||
| from onnx import AttributeProto, TensorProto, GraphProto | |||
| def make_clip_V9(): | |||
| X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4, 5]) | |||
| Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4, 5]) | |||
| node_def = helper.make_node('Clip', | |||
| inputs=['X'], | |||
| outputs=['Y'], | |||
| max = 1.0, | |||
| min = -1.0, | |||
| ) | |||
| graph = helper.make_graph( | |||
| [node_def], | |||
| "test_clip_case_V9", | |||
| [X], | |||
| [Y], | |||
| ) | |||
| model = helper.make_model(graph, producer_name="onnx-mul_test") | |||
| model.opset_import[0].version = 9 | |||
| onnx.save(model, "./onnx_clip_v9.onnx") | |||
| if __name__ == '__main__': | |||
| make_clip_V9() | |||
| @@ -24,6 +24,7 @@ | |||
| #include "st/parser_st_utils.h" | |||
| #include "external/ge/ge_api_types.h" | |||
| #include "tests/depends/ops_stub/ops_stub.h" | |||
| #include "parser/onnx/onnx_parser.h" | |||
| namespace ge { | |||
| class STestOnnxParser : public testing::Test { | |||
| @@ -128,4 +129,48 @@ TEST_F(STestOnnxParser, onnx_parser_if_node) { | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||
| } | |||
| TEST_F(STestOnnxParser, onnx_parser_expand_one_to_many) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/origin_models/onnx_clip_v9.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||
| MemBuffer *buffer = ParerSTestsUtils::MemBufferFromFile(model_file.c_str()); | |||
| ret = ge::aclgrphParseONNXFromMem(reinterpret_cast<char *>(buffer->data), buffer->size, parser_params, graph); | |||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||
| } | |||
| TEST_F(STestOnnxParser, onnx_parser_to_json) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/origin_models/onnx_clip_v9.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| OnnxModelParser onnx_parser; | |||
| const char *json_file = "tmp.json"; | |||
| auto ret = onnx_parser.ToJson(model_file.c_str(), json_file); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| const char *json_null = nullptr; | |||
| ret = onnx_parser.ToJson(model_file.c_str(), json_null); | |||
| EXPECT_EQ(ret, FAILED); | |||
| const char *model_null = nullptr; | |||
| ret = onnx_parser.ToJson(model_null, json_null); | |||
| EXPECT_EQ(ret, FAILED); | |||
| } | |||
| TEST_F(STestOnnxParser, onnx_parser_const_data_type) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/origin_models/onnx_const_type.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||
| } | |||
| } // namespace ge | |||