| @@ -322,8 +322,8 @@ namespace { | |||||
| } | } | ||||
| NodeDef *fusioninitNodeDef(int index) { | NodeDef *fusioninitNodeDef(int index) { | ||||
| NodeDef * nodeDef = new NodeDef(); | |||||
| ::google::protobuf::Map< ::std::string, ::tensorflow::AttrValue >* node_attr_map = nodeDef->mutable_attr(); | |||||
| NodeDef *nodeDef = new NodeDef(); | |||||
| google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = nodeDef->mutable_attr(); | |||||
| //设置 type属性 | //设置 type属性 | ||||
| domi::tensorflow::AttrValue dtype_attr_value ; | domi::tensorflow::AttrValue dtype_attr_value ; | ||||
| @@ -1205,8 +1205,6 @@ TEST_F(STestTensorflowParser, parse_AutoMappingByOp) { | |||||
| static const float VALUE_FLOAT = 1.0; | static const float VALUE_FLOAT = 1.0; | ||||
| static const bool VALUE_BOOL = true; | static const bool VALUE_BOOL = true; | ||||
| static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; | static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; | ||||
| std::cout << "test data_type value_type: " << (int64_t)VALUE_TYPE << std::endl; | |||||
| static const string VALUE_NAME = "test_name"; | static const string VALUE_NAME = "test_name"; | ||||
| ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>(); | ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>(); | ||||
| NodeDef node_def; | NodeDef node_def; | ||||
| @@ -2078,7 +2076,7 @@ TEST_F(STestTensorflowParser, tensorflow_arg_parser_test) | |||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| } | } | ||||
| TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test) | |||||
| TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test1) | |||||
| { | { | ||||
| TensorFlowCustomParserAdapter parser; | TensorFlowCustomParserAdapter parser; | ||||
| ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>(); | ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>(); | ||||
| @@ -2095,6 +2093,66 @@ TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test) | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | EXPECT_EQ(ret, PARAM_INVALID); | ||||
| } | } | ||||
| TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test2) | |||||
| { | |||||
| TensorFlowCustomParserAdapter parser; | |||||
| ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>(); | |||||
| NodeDef *node_def = initNodeDef(); | |||||
| node_def->set_name("FrameworkOp"); | |||||
| node_def->set_op("_Retval"); | |||||
| TensorFlowModelParser modelParser; | |||||
| std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW); | |||||
| std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("FrameworkOp"); | |||||
| shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser); | |||||
| static const string KEY_SHAPE_LIST = "key_shape_list"; | |||||
| static const string KEY_TENSOR_LIST = "key_tensor_list"; | |||||
| static const string KEY_DEFAULT = "key_default"; | |||||
| google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = node_def->mutable_attr(); | |||||
| domi::tensorflow::AttrValue dtype_attr_value; | |||||
| dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT); | |||||
| (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value; | |||||
| //设置strides属性 | |||||
| domi::tensorflow::AttrValue axis_attr_value; | |||||
| ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list(); | |||||
| list->add_i(1); | |||||
| list->add_i(2); | |||||
| (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value; | |||||
| domi::tensorflow::AttrValue value; | |||||
| domi::tensorflow::AttrValue df_attr_value; | |||||
| df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC); | |||||
| domi::tensorflow::AttrValue pad_attr_value; | |||||
| pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT); | |||||
| domi::tensorflow::AttrValue shape; | |||||
| shape.mutable_list()->add_i((int64)32); | |||||
| shape.mutable_list()->add_i((int64)32); | |||||
| shape.mutable_list()->add_i((int64)14); | |||||
| static const string KEY_TYPE_LIST = "key_type_list"; | |||||
| const std::string ATTR_NAME_INPUT_TENSOR_DESC = "ATTR_NAME_FRAMEWORK_OP_DEF"; | |||||
| const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | |||||
| static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; | |||||
| value.clear_value(); | |||||
| value.mutable_list()->add_type(VALUE_TYPE); | |||||
| TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def); | |||||
| value.clear_value(); | |||||
| domi::tensorflow::NameAttrList name_attr_list; | |||||
| name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value}); | |||||
| name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value}); | |||||
| name_attr_list.mutable_attr()->insert({"serialize_shape", shape}); | |||||
| *(value.mutable_list()->add_func()) = name_attr_list; | |||||
| node_def->mutable_attr()->insert({ge::ATTR_NAME_INPUT_TENSOR_DESC, value}); | |||||
| node_def->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value}); | |||||
| Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| TEST_F(STestTensorflowParser, tensorflow_reshape_parser_test) | TEST_F(STestTensorflowParser, tensorflow_reshape_parser_test) | ||||
| { | { | ||||
| TensorFlowCustomParserAdapter parser; | TensorFlowCustomParserAdapter parser; | ||||
| @@ -2967,16 +3025,10 @@ TEST_F(STestTensorflowParser, tensorflow_ParserNodeDef2_test) | |||||
| TEST_F(STestTensorflowParser, tensorflow_AddExternalGraph_test) | TEST_F(STestTensorflowParser, tensorflow_AddExternalGraph_test) | ||||
| { | { | ||||
| TensorFlowModelParser modelParser; | TensorFlowModelParser modelParser; | ||||
| std::string caseDir = __FILE__; | |||||
| std::size_t idx = caseDir.find_last_of("/"); | |||||
| caseDir = caseDir.substr(0, idx); | |||||
| std::string modelFile = caseDir + "/origin_models/tf_add.pb"; | |||||
| ge::Graph graph; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||||
| {AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}}; | |||||
| auto ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); | |||||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||||
| ret = modelParser.AddExternalGraph(compute_graph); | |||||
| ge::ComputeGraphPtr subGraph = std::make_shared<ge::ComputeGraph>("default"); | |||||
| std::string inputNodeType = "DATA"; | |||||
| MakeDagGraph(subGraph, inputNodeType); | |||||
| Status ret = modelParser.AddExternalGraph(subGraph); | |||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| } | } | ||||
| @@ -3023,4 +3075,25 @@ TEST_F(STestTensorflowParser, tensorflow_ParseOpParams_test) | |||||
| delete node_def; | delete node_def; | ||||
| } | } | ||||
| TEST_F(STestTensorflowParser, tensorflow_AddFusionInnerNodeDef_test) | |||||
| { | |||||
| TensorFlowModelParser model_parser; | |||||
| ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME); | |||||
| tensorflow::GraphDef *graphDef = new (std::nothrow) tensorflow::GraphDef(); | |||||
| ScopePassManager pass_manager; | |||||
| std::shared_ptr<ScopeGraph> scope_graph = pass_manager.BuildScopeGraph(graphDef); | |||||
| std::vector<std::string> op_node_name_list = {"Const", "placeholder0"}; | |||||
| FusionScopesResult *fusion_scope_rlt = new (std::nothrow) FusionScopesResult(); | |||||
| fusion_scope_rlt->Init(); | |||||
| fusion_scope_rlt->SetName("FusionCustom"); | |||||
| auto &impl_scope_graph = scope_graph->impl_; | |||||
| std::string scope_name = fusion_scope_rlt->Name(); | |||||
| impl_scope_graph->fusion_results_.insert(std::make_pair(scope_name, fusion_scope_rlt)); | |||||
| std::string fusion_op_name = "FusionCustom"; | |||||
| GenOriginNodeDef(&model_parser, op_node_name_list); | |||||
| GenFusionScopesResult(scope_graph, fusion_scope_rlt, fusion_op_name); | |||||
| Status ret = model_parser.AddFusionInnerNodeDef(scope_graph, fusion_op_name, op_node_name_list); | |||||
| delete graphDef; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||