From 3b22eb482359c0982fb491aae968e196d7b13cef Mon Sep 17 00:00:00 2001 From: jwx930962 Date: Wed, 8 Dec 2021 15:37:58 +0800 Subject: [PATCH] tensorflow/parser st --- tests/st/testcase/test_tensorflow_parser.cc | 103 +++++++++++++++++--- 1 file changed, 88 insertions(+), 15 deletions(-) diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index 22dddc8..cd96c77 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -322,8 +322,8 @@ namespace { } NodeDef *fusioninitNodeDef(int index) { - NodeDef * nodeDef = new NodeDef(); - ::google::protobuf::Map< ::std::string, ::tensorflow::AttrValue >* node_attr_map = nodeDef->mutable_attr(); + NodeDef *nodeDef = new NodeDef(); + google::protobuf::Map *node_attr_map = nodeDef->mutable_attr(); //设置 type属性 domi::tensorflow::AttrValue dtype_attr_value ; @@ -1205,8 +1205,6 @@ TEST_F(STestTensorflowParser, parse_AutoMappingByOp) { static const float VALUE_FLOAT = 1.0; static const bool VALUE_BOOL = true; static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; - - std::cout << "test data_type value_type: " << (int64_t)VALUE_TYPE << std::endl; static const string VALUE_NAME = "test_name"; ge::OpDescPtr op_desc = std::make_shared(); NodeDef node_def; @@ -2078,7 +2076,7 @@ TEST_F(STestTensorflowParser, tensorflow_arg_parser_test) EXPECT_EQ(ret, SUCCESS); } -TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test) +TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test1) { TensorFlowCustomParserAdapter parser; ge::OpDescPtr op_dest = std::make_shared(); @@ -2095,6 +2093,66 @@ TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test) EXPECT_EQ(ret, PARAM_INVALID); } +TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test2) +{ + TensorFlowCustomParserAdapter parser; + ge::OpDescPtr op_dest = std::make_shared(); + NodeDef *node_def = initNodeDef(); + node_def->set_name("FrameworkOp"); + node_def->set_op("_Retval"); + TensorFlowModelParser modelParser; + std::shared_ptr factory = OpParserFactory::Instance(domi::TENSORFLOW); + std::shared_ptr op_parser = factory->CreateOpParser("FrameworkOp"); + shared_ptr tensorflow_op_parser = std::dynamic_pointer_cast(op_parser); + static const string KEY_SHAPE_LIST = "key_shape_list"; + static const string KEY_TENSOR_LIST = "key_tensor_list"; + static const string KEY_DEFAULT = "key_default"; + + google::protobuf::Map *node_attr_map = node_def->mutable_attr(); + domi::tensorflow::AttrValue dtype_attr_value; + dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT); + (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value; + + //设置strides属性 + domi::tensorflow::AttrValue axis_attr_value; + ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list(); + list->add_i(1); + list->add_i(2); + (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value; + + domi::tensorflow::AttrValue value; + domi::tensorflow::AttrValue df_attr_value; + df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC); + + domi::tensorflow::AttrValue pad_attr_value; + pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT); + + domi::tensorflow::AttrValue shape; + shape.mutable_list()->add_i((int64)32); + shape.mutable_list()->add_i((int64)32); + shape.mutable_list()->add_i((int64)14); + + static const string KEY_TYPE_LIST = "key_type_list"; + const std::string ATTR_NAME_INPUT_TENSOR_DESC = "ATTR_NAME_FRAMEWORK_OP_DEF"; + const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; + static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; + value.clear_value(); + value.mutable_list()->add_type(VALUE_TYPE); + TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def); + + value.clear_value(); + domi::tensorflow::NameAttrList name_attr_list; + name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value}); + name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value}); + name_attr_list.mutable_attr()->insert({"serialize_shape", shape}); + *(value.mutable_list()->add_func()) = name_attr_list; + + node_def->mutable_attr()->insert({ge::ATTR_NAME_INPUT_TENSOR_DESC, value}); + node_def->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value}); + Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest); + EXPECT_EQ(ret, SUCCESS); +} + TEST_F(STestTensorflowParser, tensorflow_reshape_parser_test) { TensorFlowCustomParserAdapter parser; @@ -2967,16 +3025,10 @@ TEST_F(STestTensorflowParser, tensorflow_ParserNodeDef2_test) TEST_F(STestTensorflowParser, tensorflow_AddExternalGraph_test) { TensorFlowModelParser modelParser; - std::string caseDir = __FILE__; - std::size_t idx = caseDir.find_last_of("/"); - caseDir = caseDir.substr(0, idx); - std::string modelFile = caseDir + "/origin_models/tf_add.pb"; - ge::Graph graph; - std::map parser_params = { - {AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}}; - auto ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); - ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); - ret = modelParser.AddExternalGraph(compute_graph); + ge::ComputeGraphPtr subGraph = std::make_shared("default"); + std::string inputNodeType = "DATA"; + MakeDagGraph(subGraph, inputNodeType); + Status ret = modelParser.AddExternalGraph(subGraph); EXPECT_EQ(ret, SUCCESS); } @@ -3023,4 +3075,25 @@ TEST_F(STestTensorflowParser, tensorflow_ParseOpParams_test) delete node_def; } +TEST_F(STestTensorflowParser, tensorflow_AddFusionInnerNodeDef_test) +{ + TensorFlowModelParser model_parser; + ge::ComputeGraphPtr compute_graph = std::make_shared(GRAPH_DEFAULT_NAME); + tensorflow::GraphDef *graphDef = new (std::nothrow) tensorflow::GraphDef(); + ScopePassManager pass_manager; + std::shared_ptr scope_graph = pass_manager.BuildScopeGraph(graphDef); + std::vector op_node_name_list = {"Const", "placeholder0"}; + FusionScopesResult *fusion_scope_rlt = new (std::nothrow) FusionScopesResult(); + fusion_scope_rlt->Init(); + fusion_scope_rlt->SetName("FusionCustom"); + auto &impl_scope_graph = scope_graph->impl_; + std::string scope_name = fusion_scope_rlt->Name(); + impl_scope_graph->fusion_results_.insert(std::make_pair(scope_name, fusion_scope_rlt)); + std::string fusion_op_name = "FusionCustom"; + GenOriginNodeDef(&model_parser, op_node_name_list); + GenFusionScopesResult(scope_graph, fusion_scope_rlt, fusion_op_name); + Status ret = model_parser.AddFusionInnerNodeDef(scope_graph, fusion_op_name, op_node_name_list); + delete graphDef; +} + } // namespace ge