
!570 sync parser to master 20220616

Merge pull request !570 from yangyongqiang/ge_dev (pull/575/MERGE)
Author: zhangfan (Gitee), 3 years ago
Commit: e3b8420661
GPG Key ID: 173E9B9CA92EEF8F (no known key found for this signature in database)
7 changed files with 470 additions and 7 deletions
  1. parser/common/acl_graph_parser_util.cc (+4, -1)
  2. tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc (+87, -1)
  3. tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc (+1, -0)
  4. tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc (+28, -0)
  5. tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc (+76, -0)
  6. tests/ut/parser/testcase/tensorflow_parser_testcase/origin_models/getnext_dynamic_fusion.pbtxt (+174, -0)
  7. tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc (+100, -5)

parser/common/acl_graph_parser_util.cc (+4, -1)

@@ -22,6 +22,7 @@
 #include <cstdlib>
 #include <ctime>
 #include <fstream>
+#include <atomic>
 
 #include "common/string_util.h"
 #include "common/util.h"
@@ -61,6 +62,7 @@ const std::set<domi::FrameworkType> kSupportTensorAsOutput = {
   domi::CAFFE,
   domi::ONNX
 };
+std::atomic<uint32_t> graph_name_index {};
 
 static string GetSoPath() {
   Dl_info dl_info;
@@ -637,7 +639,8 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin
 
   string tmp_name;
   GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name);
-  graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + ge::parser::CurrentTimeInStr()) : tmp_name;
+  graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" +
+      ge::parser::CurrentTimeInStr() + "_" + std::to_string(graph_name_index++)) : tmp_name;
 
   string enable_scope_fusion_passes;
   GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes);
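
For readers skimming the diff: the change above appends a process-wide std::atomic counter to the timestamped fallback graph name, so two graphs parsed in the same second (or on parallel threads) no longer collide. A minimal standalone sketch of that pattern, with hypothetical names and std::time standing in for ge::parser::CurrentTimeInStr:

```cpp
#include <atomic>
#include <cstdint>
#include <ctime>
#include <string>

// Process-wide counter; fetch-and-increment is atomic, so concurrent parses
// never receive the same index.
static std::atomic<uint32_t> g_graph_name_index{0};

// Hypothetical stand-in for the fallback-name logic: a timestamp alone can
// collide when two graphs are parsed within the same second, so the counter
// suffix keeps default names distinct.
std::string MakeDefaultGraphName(const std::string &base) {
  return base + "_" + std::to_string(std::time(nullptr)) + "_" +
         std::to_string(g_graph_name_index++);
}
```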


tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc (+87, -1)

@@ -40,6 +40,7 @@
 #include "parser/common/acl_graph_parser_util.h"
 #include "parser/caffe/caffe_reshape_parser.h"
 #include "common/op_map.h"
+#include "parser/common/prototype_pass_manager.h"
 #undef protected
 #undef private
 
@@ -51,6 +52,7 @@
 
 using namespace domi::caffe;
 using namespace ge;
+using CreateFn = std::function<ProtoTypeBasePass *(void)>;
 
 namespace ge {
 class UtestCaffeParser : public testing::Test {
@@ -66,6 +68,11 @@ class UtestCaffeParser : public testing::Test {
   void RegisterCustomOp();
 };
 
+class RegisterPass : public ProtoTypeBasePass {
+ public:
+  Status Run(google::protobuf::Message *message) { return SUCCESS; }
+};
+
 static ge::NodePtr GenNodeFromOpDesc(ge::OpDescPtr opDesc){
   if (!opDesc) {
     return nullptr;
@@ -835,6 +842,19 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerParameter_test)
 {
   CaffeWeightsParser weightParser;
   ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph");
+  auto tensor_desc = std::make_shared<GeTensorDesc>();
+  tensor_desc->SetShape(GeShape({1}));
+  tensor_desc->SetDataType(DT_FLOAT);
+  tensor_desc->SetFormat(FORMAT_CHWN);
+
+  auto op_desc = std::make_shared<OpDesc>("Abs", "Abs");
+  op_desc->AddInputDesc(tensor_desc->Clone());
+  auto node = compute_graph->AddNode(op_desc);
+  auto op_desc1 = std::make_shared<OpDesc>("Abs", "Abs");
+  op_desc1->AddInputDesc(tensor_desc->Clone());
+  auto nodeptr = compute_graph->AddNodeFront(node);
+
   domi::caffe::NetParameter net;
   ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Abs", "AbsVal");
   domi::caffe::LayerParameter *layer = net.add_layer();
@@ -1142,9 +1162,17 @@ TEST_F(UtestCaffeParser, CaffeModelParser_ParseLayerParameter_test)
   auto descriptor = importer.pool()->FindMessageTypeByName("domi.caffe.LayerParameter");
   google::protobuf::DynamicMessageFactory factory;
   const google::protobuf::Message *proto = factory.GetPrototype(descriptor);
-  const google::protobuf::Message *message = proto->New();
+  google::protobuf::Message *message = proto->New();
   Status ret = modelParser.ParseLayerParameter(descriptor, message, operators);
   EXPECT_EQ(ret, SUCCESS);
 
+  const domi::FrameworkType fmk_type = domi::TENSORFLOW;
+  const char_t *const pass_name = "PASS_NAME";
+  auto func = [&](){ return new (std::nothrow) RegisterPass();};
+  CreateFn create_fn = func;
+  ProtoTypePassRegistry::GetInstance().RegisterProtoTypePass(pass_name, create_fn, fmk_type);
+  ret = ProtoTypePassManager::Instance().Run(message, fmk_type);
+  EXPECT_EQ(ret, SUCCESS);
   delete message;
 }


@@ -1192,6 +1220,64 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ReorderInput_test)
   layer2->set_name("Data");
   layer2->set_type("Input");
   modelParser.ReorderInput(net);
+
+  std::vector<int32_t> idx_vector = {0,1,2,4};
+  ge::GetParserContext().out_nodes_map.insert(pair<std::string, std::vector<int32_t>>("add", idx_vector));
+  const string op_name = "add";
+  const int32_t index = 0;
+  bool ret = modelParser.IsOutputTop(op_name, index);
+  EXPECT_EQ(ret, true);
+}
+
+TEST_F(UtestCaffeParser, CaffeOpParser_ParseParms_test)
+{
+  CaffeOpParser parser;
+  std::string case_dir = __FILE__;
+  case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
+  std::string caffe_proto = case_dir + "/../../../../../metadef/proto/caffe/";
+  google::protobuf::compiler::DiskSourceTree sourceTree;
+  sourceTree.MapPath("project_root", caffe_proto);
+  google::protobuf::compiler::Importer importer(&sourceTree, nullptr);
+  importer.Import("project_root/caffe.proto");
+  auto descriptor = importer.pool()->FindMessageTypeByName("domi.caffe.LayerParameter");
+  ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Abs", "AbsVal");
+  google::protobuf::DynamicMessageFactory factory;
+  const google::protobuf::Message *proto = factory.GetPrototype(descriptor);
+  const google::protobuf::Message *message = proto->New();
+  ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);
+  Status ret = parser.ParseParams(message, op_src);
+  EXPECT_EQ(ret, SUCCESS);
+}
+
+TEST_F(UtestCaffeParser, CaffeModelParser_Constructor_and_delete)
+{
+  CaffeModelParser modelParser;
+  domi::caffe::NetParameter net;
+  net.add_input("111");
+  bool input_data_flag = true;
+  net.add_input_shape();
+  Status ret = modelParser.ParseInput(net, input_data_flag);
+  EXPECT_EQ(ret, SUCCESS);
+}
+
+TEST_F(UtestCaffeParser, ParseFromMemory_success_graph)
+{
+  std::string caseDir = __FILE__;
+  std::size_t idx = caseDir.find_last_of("/");
+  caseDir = caseDir.substr(0, idx);
+  std::string modelFile = caseDir + "/caffe_model/caffe_add.pbtxt";
+  std::string weight_file = caseDir + "/caffe_model/caffe_add.caffemodel";
+
+  const char* tmp_tf_pb_model = modelFile.c_str();
+  const char* tmp_tf_weight_model = weight_file.c_str();
+  ge::Graph graph;
+
+  Status ret = ge::aclgrphParseCaffe(modelFile.c_str(), weight_file.c_str(), graph);
+  CaffeModelParser modelParser;
+  MemBuffer* memBuffer1 = ParerUTestsUtils::MemBufferFromFile(tmp_tf_pb_model);
+  ret = modelParser.ParseFromMemory((char*)memBuffer1->data, memBuffer1->size, graph);
+  EXPECT_EQ(ret, SUCCESS);
+  delete memBuffer1;
 }
 
 } // namespace ge
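
The new Caffe test cases above build a domi.caffe.LayerParameter dynamically from caffe.proto instead of using generated classes. A self-contained sketch of that protobuf dynamic-message pattern, with a placeholder proto directory rather than the repository-relative path used in the tests:

```cpp
#include <google/protobuf/compiler/importer.h>
#include <google/protobuf/dynamic_message.h>
#include <iostream>
#include <memory>

int main() {
  // Map a virtual root onto a directory containing caffe.proto
  // ("/path/to/proto" is a placeholder).
  google::protobuf::compiler::DiskSourceTree source_tree;
  source_tree.MapPath("project_root", "/path/to/proto");

  // Parse the .proto at runtime; Import() returns nullptr on failure.
  google::protobuf::compiler::Importer importer(&source_tree, nullptr);
  if (importer.Import("project_root/caffe.proto") == nullptr) {
    return 1;
  }

  // Look up the message type and create a mutable instance without any
  // generated C++ code.
  const google::protobuf::Descriptor *descriptor =
      importer.pool()->FindMessageTypeByName("domi.caffe.LayerParameter");
  if (descriptor == nullptr) {
    return 1;
  }
  google::protobuf::DynamicMessageFactory factory;
  std::unique_ptr<google::protobuf::Message> message(
      factory.GetPrototype(descriptor)->New());
  std::cout << message->GetTypeName() << std::endl;
  return 0;
}
```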

tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc (+1, -0)

@@ -344,6 +344,7 @@ TEST_F(UtestAclGraphParser, test_operatoreq)
 }
 
 TEST_F(UtestAclGraphParser, test_pre_checker) {
+  TBEPluginLoader tbe_plugin;
   PreChecker::Instance().fmk_op_types_ = nullptr;
   const char* str = "iiii";
   PreChecker::OpId id = str;


tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc (+28, -0)

@@ -69,4 +69,32 @@ TEST_F(UtestMessage2Operator, pb2json_one_field_json) {
   Json json;
   ge::Pb2Json::Message2Json(input_node, std::set<std::string>{}, json, true);
 }
+
+TEST_F(UtestMessage2Operator, pb2json_one_field_json_depth_max) {
+  ge::onnx::NodeProto input_node;
+  ge::onnx::AttributeProto *attribute = input_node.add_attribute();
+  attribute->set_name("attribute");
+  attribute->set_type(onnx::AttributeProto::AttributeType(1));
+  ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
+  attribute_tensor->set_data_type(1);
+  attribute_tensor->add_dims(4);
+  attribute_tensor->set_raw_data("\007");
+  Json json;
+  ge::Pb2Json::Message2Json(input_node, std::set<std::string>{}, json, true, 21);
+}
+
+TEST_F(UtestMessage2Operator, pb2json_one_field_json_type) {
+  ge::onnx::NodeProto input_node;
+  ge::onnx::AttributeProto *attribute = input_node.add_attribute();
+  attribute->set_name("attribute");
+  attribute->set_type(onnx::AttributeProto::AttributeType(1));
+  ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
+  attribute_tensor->set_data_type(3);
+  attribute_tensor->add_dims(4);
+  attribute_tensor->set_raw_data("\007");
+  Json json;
+  ge::Pb2Json::Message2Json(input_node, std::set<std::string>{}, json, true);
+}
+
 } // namespace ge

tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc (+76, -0)

@@ -218,6 +218,47 @@ TEST_F(UtestOnnxParser, onnx_parser_to_json) {
   const char *model_null = nullptr;
   ret = onnx_parser.ToJson(model_null, json_null);
   EXPECT_EQ(ret, FAILED);
+
+  char *data = nullptr;
+  uint32_t size = 0;
+  ge::ComputeGraphPtr graph;
+  ret = onnx_parser.ParseFromMemory(data, size, graph);
+  EXPECT_EQ(ret, SUCCESS);
+
+  google::protobuf::Message *proto = nullptr;
+  ret = onnx_parser.ParseProto(proto, graph);
+  EXPECT_EQ(ret, SUCCESS);
+
+  domi::GetGraphCallback callback;
+  ret = onnx_parser.ParseProtoWithSubgraph(proto, callback, graph);
+  EXPECT_EQ(ret, SUCCESS);
+
+  ret = onnx_parser.ParseAllGraph(proto, graph);
+  EXPECT_EQ(ret, SUCCESS);
+
+  string file = "./";
+  ret = onnx_parser.Save(file);
+  EXPECT_NE(ret, SUCCESS);
+
+  bool ret1 = onnx_parser.HasError();
+  EXPECT_EQ(ret1, SUCCESS);
+  onnx_parser.Clear();
+
+  OnnxWeightsParser onnx_weight_parser;
+  char *file1 = nullptr;
+  ge::Graph graph1;
+  ret = onnx_weight_parser.Parse(file1, graph1);
+  EXPECT_EQ(ret, SUCCESS);
+
+  ret = onnx_weight_parser.ParseFromMemory(data, size, graph);
+  EXPECT_EQ(ret, SUCCESS);
+
+  ret1 = onnx_weight_parser.HasError();
+  EXPECT_EQ(ret1, SUCCESS);
+
+  ret = onnx_weight_parser.Save(file);
+  EXPECT_NE(ret, SUCCESS);
+  onnx_weight_parser.Clear();
 }
 
 TEST_F(UtestOnnxParser, onnx_parser_const_data_type) {
@@ -243,6 +284,7 @@ TEST_F(UtestOnnxParser, OnnxModelParser_ConvertToGeDataType_test)
   EXPECT_EQ(ret, ge::DataType::DT_UNDEFINED);
 }
 
+
 TEST_F(UtestOnnxParser, OnnxModelParser_ParseConvertData_test)
 {
   OnnxConstantParser constant_parser;
@@ -278,6 +320,23 @@ TEST_F(UtestOnnxParser, OnnxModelParser_ParseConvertData_test)
   EXPECT_EQ(ret, SUCCESS);
 }
 
+TEST_F(UtestOnnxParser, OnnxModelParser_ParseConvertData_test_bool)
+{
+  OnnxConstantParser constant_parser;
+  ge::onnx::TensorProto tensor_proto;
+  tensor_proto.set_data_type(OnnxDataType::INT32);
+  ge::Tensor tensor;
+  TensorDesc tensor_desc = tensor.GetTensorDesc();
+  tensor_desc.SetDataType(ge::DataType::DT_BOOL);
+  tensor.SetTensorDesc(tensor_desc);
+  int count = 1;
+  tensor_proto.set_raw_data("Test");
+  Status ret = constant_parser.ParseConvertData(tensor_proto, tensor, count);
+  EXPECT_EQ(ret, SUCCESS);
+}
+
 TEST_F(UtestOnnxParser, OnnxConstantParser_ParseConvertTensor_test)
 {
   OnnxConstantParser constant_parser;
@@ -423,4 +482,21 @@ TEST_F(UtestOnnxParser, onnx_test_GetModelFromMemory)
   EXPECT_EQ(ret, FAILED);
 }
 
+TEST_F(UtestOnnxParser, onnx_test_TransNodeToOperator_SetTensorData)
+{
+  ge::onnx::ModelProto model_proto;
+  ge::onnx::GraphProto* graph = model_proto.mutable_graph();
+  ge::onnx::NodeProto *node_proto = graph->add_node();
+  node_proto->set_op_type("Add1");
+  node_proto->set_domain("add.onnx");
+  node_proto->set_name("Conv2D");
+  ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Add", "add.onnx");
+  ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);
+  std::string op_type = "Add";
+
+  OnnxModelParser onnx_parser;
+  Status ret = onnx_parser.TransNodeToOperator(node_proto, op, op_type);
+  EXPECT_EQ(ret, SUCCESS);
+}
+
 } // namespace ge

tests/ut/parser/testcase/tensorflow_parser_testcase/origin_models/getnext_dynamic_fusion.pbtxt (+174, -0)

@@ -0,0 +1,174 @@
node {
name: "IteratorV2"
op: "IteratorV2"
attr {
key: "op_def"
value {
s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001"
}
}
attr {
key: "output_types"
value {
list {
type: DT_INT64
}
}
}
attr {
key: "output_tensor_desc"
value {
list {
func {
name: "0"
attr {
key: "serialize_datatype"
value: {
i: 9
}
}
attr {
key: "serialize_format"
value: {
i: 1
}
}
attr {
key: "serialize_shape"
value {
type: DT_INT32
}
}
}
}
}
}
}
node {
name: "IteratorGetNext"
op: "IteratorGetNext"
input: "IteratorV2"
attr {
key: "output_types"
value {
list {
type: DT_INT64
}
}
}
attr {
key: "op_def"
value {
s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001"
}
}
attr {
key: "input_tensor_desc"
value {
list {
func {
name: "0"
attr {
key: "serialize_datatype"
value: {
i: 9
}
}
attr {
key: "serialize_format"
value: {
i: 1
}
}
attr {
key: "serialize_shape"
value {
type: DT_INT32
}
}
}
}
}
}
attr {
key: "output_tensor_desc"
value {
list {
func {
name: "0"
attr {
key: "serialize_datatype"
value: {
i: 9
}
}
attr {
key: "serialize_format"
value: {
i: 1
}
}
attr {
key: "serialize_shape"
value {
list {
i: -1
i: -1
}
}
}
}
}
}
}
}
node {
name: "getnext_shape_0"
op: "Shape"
input: "IteratorGetNext"
attr {
key: "op_def"
value {
s: "\n\005Shape\022\n\n\005input\"\001T\032\022\n\006output\"\010out_type\"\t\n\001T\022\004type\"\034\n\010out_type\022\004type\032\0020\003:\006\n\0042\002\003\t"
}
}
}
node {
name: "retval_GetNext_0_0"
op: "_Retval"
input: "IteratorGetNext"
attr {
key: "index"
value {
i: 0
}
}
attr {
key: "op_def"
value {
s: ""
}
}
}
node {
name: "retval_GetNext_0_1"
op: "_Retval"
input: "getnext_shape_0"
attr {
key: "index"
value {
i: 1
}
}
attr {
key: "op_def"
value {
s: ""
}
}
}
library {
}
versions {
producer: 134
}

tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc (+100, -5)

@@ -1398,6 +1398,13 @@ TEST_F(UtestTensorflowParser, tensorflow_ParserProto_failed)
   ASSERT_EQ(ret, PARAM_INVALID);
 }
 
+std::unique_ptr<google::protobuf::Message> getGraphCallback(const google::protobuf::Message *root_proto, const std::string &graph)
+{
+  (void)root_proto;
+  (void)graph;
+  return nullptr;
+}
+
 TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed)
 {
   std::string caseDir = __FILE__;
@@ -1422,6 +1429,11 @@ TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed)
   TensorFlowModelParser tensorflow_parser;
   ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
   ASSERT_NE(ret, SUCCESS);
+
+  domi::GetGraphCallback callback(&getGraphCallback);
+  const auto message_root_proto = reinterpret_cast<google::protobuf::Message *>(&graphDef);
+  ret = tensorflow_parser.ParseProtoWithSubgraph(message_root_proto, callback, root_graph);
+  ASSERT_NE(ret, SUCCESS);
 }
 
 TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes)
@@ -3768,6 +3780,8 @@ TEST_F(UtestTensorflowParser, tensorflow_tbe_tfplugin_loader_test)
   pluginLoad.ProcessSoFullName(fileList, caffeParserPath, full_name, caffe_parser_so_suff);
   ASSERT_EQ(caffeParserPath, full_name);
 
+  void *p = (void*)malloc(sizeof(int));
+  pluginLoad.handles_vec_.push_back(p);
   pluginLoad.ClearHandles_();
 
   std::cout << __FILE__ << std::endl;
@@ -3942,9 +3956,19 @@ TEST_F(UtestTensorflowParser, custom_parser_adapter_register)
   ASSERT_EQ(nullptr, func);
 }
 
+static Status ParseParamsStub1(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
+  return SUCCESS;
+}
+
 TEST_F(UtestTensorflowParser, tensorflow_parser_api_test)
 {
+  REGISTER_CUSTOM_OP("Add11")
+      .FrameworkType(domi::TENSORFLOW)
+      .OriginOpType("Add11")
+      .ParseParamsFn(ParseParamsStub1);
   std::map<std::string, std::string> options = {{"ge.runFlag", "1"}};
+  options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW)));
   Status ret = ParserInitialize(options);
   EXPECT_EQ(ret, SUCCESS);


@@ -3958,6 +3982,24 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_api_test)
   EXPECT_EQ(ret, SUCCESS);
 }
 
+TEST_F(UtestTensorflowParser, tensorflow_parser_api_test_cafee)
+{
+  std::map<std::string, std::string> options = {{"ge.runFlag", "1"}};
+  options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::CAFFE)));
+  Status ret = ParserInitialize(options);
+  EXPECT_EQ(ret, SUCCESS);
+  options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::CAFFE)));
+
+  ret = ParserInitialize(options);
+  EXPECT_EQ(ret, SUCCESS);
+
+  ret = ParserFinalize();
+  EXPECT_EQ(ret, SUCCESS);
+
+  ret = ParserFinalize();
+  EXPECT_EQ(ret, SUCCESS);
+}
+
 TEST_F(UtestTensorflowParser, tensorflow_FP16_parser_test)
 {
   parser::fp16_t fp16;
@@ -4154,6 +4196,36 @@ TEST_F(UtestTensorflowParser, parser_UpdateGraph_test)
   EXPECT_EQ(ret, PARAM_INVALID);
 }
 
+TEST_F(UtestTensorflowParser, tensorflow_optimizer_fmk_fusion_op_) {
+  std::string caseDir = __FILE__;
+  std::size_t idx = caseDir.find_last_of("/");
+  caseDir = caseDir.substr(0, idx);
+  const std::string root_proto = caseDir + "/origin_models/getnext_dynamic_fusion.pbtxt";
+  domi::tensorflow::GraphDef graphDef;
+
+  bool protoRet = parser::ReadProtoFromText(root_proto.c_str(), &graphDef);
+  ASSERT_EQ(protoRet, true);
+
+  TensorFlowModelParser tensorflow_parser;
+  ge::ComputeGraphPtr root_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph");
+  Status ret = tensorflow_parser.ParseProto(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
+  EXPECT_EQ(ret, SUCCESS);
+  EXPECT_EQ(root_graph->GetDirectNode().size(), 3);
+}
+
+TEST_F(UtestTensorflowParser, parser_UpdateGraph_node_0)
+{
+  std::vector<NodePtr> nodes;
+  ge::ComputeGraphPtr subGraph = std::make_shared<ge::ComputeGraph>("default");
+  ParserGraphOptimizer graphOptimizer(subGraph, domi::TENSORFLOW);
+  Status ret = graphOptimizer.UpdateGraph(nodes);
+  EXPECT_EQ(ret, PARAM_INVALID);
+}
+
 TEST_F(UtestTensorflowParser, parser_RebuildFusionNode_test)
 {
   ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
@@ -4572,7 +4644,7 @@ TEST_F(UtestTensorflowParser, tensorflow_SoftmaxAddAttr)
 
 TEST_F(UtestTensorflowParser, tensorflow_InferInputFormats)
 {
-  domiTensorFormat_t ret;
+  domiTensorFormat_t ret2;
   TensorFlowModelParser modelParser;
 
   GetParserContext().format = DOMI_TENSOR_RESERVED;
@@ -4580,15 +4652,38 @@ TEST_F(UtestTensorflowParser, tensorflow_InferInputFormats)
   NodeDef *node = MallocNodeDef("node", "DATA");
   modelParser.nodedef_map_["node"] = node;
   tensorflow_op_map["DATA"] = "node";
-  ret = modelParser.InferInputFormats();
-  EXPECT_EQ(ret, domi::DOMI_TENSOR_NHWC);
+  ret2 = modelParser.InferInputFormats();
+  EXPECT_EQ(ret2, domi::DOMI_TENSOR_NHWC);
   delete node;
   NodeDef* node1 = nullptr;
   modelParser.nodedef_map_["node"] = node1;
 
-  ret = modelParser.InferInputFormats();
-  EXPECT_EQ(ret, domi::DOMI_TENSOR_RESERVED);
+  ret2 = modelParser.InferInputFormats();
+  EXPECT_EQ(ret2, domi::DOMI_TENSOR_RESERVED);
+
+  char *data = nullptr;
+  uint32_t size = 0;
+  ge::Graph graph;
+  Status ret = modelParser.ParseFromMemory(data, size, graph);
+  EXPECT_EQ(ret, SUCCESS);
+
+  string file = "./";
+  ret = modelParser.Save(file);
+  EXPECT_NE(ret, SUCCESS);
+
+  bool ret1 = modelParser.HasError();
+  EXPECT_EQ(ret1, SUCCESS);
+  modelParser.Clear();
+
+  TensorFlowWeightsParser tensorflow_weights_parser;
+  string file_path = "./";
+  ret = tensorflow_weights_parser.Save(file_path);
+  EXPECT_NE(ret, SUCCESS);
+
+  ret1 = tensorflow_weights_parser.HasError();
+  EXPECT_EQ(ret1, SUCCESS);
+  tensorflow_weights_parser.Clear();
 }
 
 TEST_F(UtestTensorflowParser, tensorflow_GetTransposeInfo)

