| @@ -322,8 +322,8 @@ namespace { | |||
| } | |||
| NodeDef *fusioninitNodeDef(int index) { | |||
| NodeDef * nodeDef = new NodeDef(); | |||
| ::google::protobuf::Map< ::std::string, ::tensorflow::AttrValue >* node_attr_map = nodeDef->mutable_attr(); | |||
| NodeDef *nodeDef = new NodeDef(); | |||
| google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = nodeDef->mutable_attr(); | |||
| //设置 type属性 | |||
| domi::tensorflow::AttrValue dtype_attr_value ; | |||
| @@ -1205,8 +1205,6 @@ TEST_F(STestTensorflowParser, parse_AutoMappingByOp) { | |||
| static const float VALUE_FLOAT = 1.0; | |||
| static const bool VALUE_BOOL = true; | |||
| static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; | |||
| std::cout << "test data_type value_type: " << (int64_t)VALUE_TYPE << std::endl; | |||
| static const string VALUE_NAME = "test_name"; | |||
| ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>(); | |||
| NodeDef node_def; | |||
| @@ -2078,7 +2076,7 @@ TEST_F(STestTensorflowParser, tensorflow_arg_parser_test) | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test) | |||
| TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test1) | |||
| { | |||
| TensorFlowCustomParserAdapter parser; | |||
| ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>(); | |||
| @@ -2095,6 +2093,66 @@ TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test) | |||
| EXPECT_EQ(ret, PARAM_INVALID); | |||
| } | |||
| TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test2) | |||
| { | |||
| TensorFlowCustomParserAdapter parser; | |||
| ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>(); | |||
| NodeDef *node_def = initNodeDef(); | |||
| node_def->set_name("FrameworkOp"); | |||
| node_def->set_op("_Retval"); | |||
| TensorFlowModelParser modelParser; | |||
| std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW); | |||
| std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("FrameworkOp"); | |||
| shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser); | |||
| static const string KEY_SHAPE_LIST = "key_shape_list"; | |||
| static const string KEY_TENSOR_LIST = "key_tensor_list"; | |||
| static const string KEY_DEFAULT = "key_default"; | |||
| google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = node_def->mutable_attr(); | |||
| domi::tensorflow::AttrValue dtype_attr_value; | |||
| dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT); | |||
| (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value; | |||
| //设置strides属性 | |||
| domi::tensorflow::AttrValue axis_attr_value; | |||
| ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list(); | |||
| list->add_i(1); | |||
| list->add_i(2); | |||
| (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value; | |||
| domi::tensorflow::AttrValue value; | |||
| domi::tensorflow::AttrValue df_attr_value; | |||
| df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC); | |||
| domi::tensorflow::AttrValue pad_attr_value; | |||
| pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT); | |||
| domi::tensorflow::AttrValue shape; | |||
| shape.mutable_list()->add_i((int64)32); | |||
| shape.mutable_list()->add_i((int64)32); | |||
| shape.mutable_list()->add_i((int64)14); | |||
| static const string KEY_TYPE_LIST = "key_type_list"; | |||
| const std::string ATTR_NAME_INPUT_TENSOR_DESC = "ATTR_NAME_FRAMEWORK_OP_DEF"; | |||
| const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | |||
| static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; | |||
| value.clear_value(); | |||
| value.mutable_list()->add_type(VALUE_TYPE); | |||
| TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def); | |||
| value.clear_value(); | |||
| domi::tensorflow::NameAttrList name_attr_list; | |||
| name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value}); | |||
| name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value}); | |||
| name_attr_list.mutable_attr()->insert({"serialize_shape", shape}); | |||
| *(value.mutable_list()->add_func()) = name_attr_list; | |||
| node_def->mutable_attr()->insert({ge::ATTR_NAME_INPUT_TENSOR_DESC, value}); | |||
| node_def->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value}); | |||
| Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| TEST_F(STestTensorflowParser, tensorflow_reshape_parser_test) | |||
| { | |||
| TensorFlowCustomParserAdapter parser; | |||
| @@ -2967,16 +3025,10 @@ TEST_F(STestTensorflowParser, tensorflow_ParserNodeDef2_test) | |||
| TEST_F(STestTensorflowParser, tensorflow_AddExternalGraph_test) | |||
| { | |||
| TensorFlowModelParser modelParser; | |||
| std::string caseDir = __FILE__; | |||
| std::size_t idx = caseDir.find_last_of("/"); | |||
| caseDir = caseDir.substr(0, idx); | |||
| std::string modelFile = caseDir + "/origin_models/tf_add.pb"; | |||
| ge::Graph graph; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}}; | |||
| auto ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); | |||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
| ret = modelParser.AddExternalGraph(compute_graph); | |||
| ge::ComputeGraphPtr subGraph = std::make_shared<ge::ComputeGraph>("default"); | |||
| std::string inputNodeType = "DATA"; | |||
| MakeDagGraph(subGraph, inputNodeType); | |||
| Status ret = modelParser.AddExternalGraph(subGraph); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| @@ -3023,4 +3075,25 @@ TEST_F(STestTensorflowParser, tensorflow_ParseOpParams_test) | |||
| delete node_def; | |||
| } | |||
| TEST_F(STestTensorflowParser, tensorflow_AddFusionInnerNodeDef_test) | |||
| { | |||
| TensorFlowModelParser model_parser; | |||
| ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME); | |||
| tensorflow::GraphDef *graphDef = new (std::nothrow) tensorflow::GraphDef(); | |||
| ScopePassManager pass_manager; | |||
| std::shared_ptr<ScopeGraph> scope_graph = pass_manager.BuildScopeGraph(graphDef); | |||
| std::vector<std::string> op_node_name_list = {"Const", "placeholder0"}; | |||
| FusionScopesResult *fusion_scope_rlt = new (std::nothrow) FusionScopesResult(); | |||
| fusion_scope_rlt->Init(); | |||
| fusion_scope_rlt->SetName("FusionCustom"); | |||
| auto &impl_scope_graph = scope_graph->impl_; | |||
| std::string scope_name = fusion_scope_rlt->Name(); | |||
| impl_scope_graph->fusion_results_.insert(std::make_pair(scope_name, fusion_scope_rlt)); | |||
| std::string fusion_op_name = "FusionCustom"; | |||
| GenOriginNodeDef(&model_parser, op_node_name_list); | |||
| GenFusionScopesResult(scope_graph, fusion_scope_rlt, fusion_op_name); | |||
| Status ret = model_parser.AddFusionInnerNodeDef(scope_graph, fusion_op_name, op_node_name_list); | |||
| delete graphDef; | |||
| } | |||
| } // namespace ge | |||