Browse Source

tensorflow/parser st

pull/419/head
jwx930962 4 years ago
parent
commit
3b22eb4823
1 changed files with 88 additions and 15 deletions
  1. +88
    -15
      tests/st/testcase/test_tensorflow_parser.cc

+ 88
- 15
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -322,8 +322,8 @@ namespace {
} }


NodeDef *fusioninitNodeDef(int index) { NodeDef *fusioninitNodeDef(int index) {
NodeDef * nodeDef = new NodeDef();
::google::protobuf::Map< ::std::string, ::tensorflow::AttrValue >* node_attr_map = nodeDef->mutable_attr();
NodeDef *nodeDef = new NodeDef();
google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = nodeDef->mutable_attr();


//设置 type属性 //设置 type属性
domi::tensorflow::AttrValue dtype_attr_value ; domi::tensorflow::AttrValue dtype_attr_value ;
@@ -1205,8 +1205,6 @@ TEST_F(STestTensorflowParser, parse_AutoMappingByOp) {
static const float VALUE_FLOAT = 1.0; static const float VALUE_FLOAT = 1.0;
static const bool VALUE_BOOL = true; static const bool VALUE_BOOL = true;
static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;

std::cout << "test data_type value_type: " << (int64_t)VALUE_TYPE << std::endl;
static const string VALUE_NAME = "test_name"; static const string VALUE_NAME = "test_name";
ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>(); ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>();
NodeDef node_def; NodeDef node_def;
@@ -2078,7 +2076,7 @@ TEST_F(STestTensorflowParser, tensorflow_arg_parser_test)
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }


TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test)
TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test1)
{ {
TensorFlowCustomParserAdapter parser; TensorFlowCustomParserAdapter parser;
ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>(); ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
@@ -2095,6 +2093,66 @@ TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test)
EXPECT_EQ(ret, PARAM_INVALID); EXPECT_EQ(ret, PARAM_INVALID);
} }


TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test2)
{
TensorFlowCustomParserAdapter parser;
ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
NodeDef *node_def = initNodeDef();
node_def->set_name("FrameworkOp");
node_def->set_op("_Retval");
TensorFlowModelParser modelParser;
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("FrameworkOp");
shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
static const string KEY_SHAPE_LIST = "key_shape_list";
static const string KEY_TENSOR_LIST = "key_tensor_list";
static const string KEY_DEFAULT = "key_default";

google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = node_def->mutable_attr();
domi::tensorflow::AttrValue dtype_attr_value;
dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
(*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;

//设置strides属性
domi::tensorflow::AttrValue axis_attr_value;
::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list();
list->add_i(1);
list->add_i(2);
(*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value;

domi::tensorflow::AttrValue value;
domi::tensorflow::AttrValue df_attr_value;
df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);

domi::tensorflow::AttrValue pad_attr_value;
pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);

domi::tensorflow::AttrValue shape;
shape.mutable_list()->add_i((int64)32);
shape.mutable_list()->add_i((int64)32);
shape.mutable_list()->add_i((int64)14);
static const string KEY_TYPE_LIST = "key_type_list";
const std::string ATTR_NAME_INPUT_TENSOR_DESC = "ATTR_NAME_FRAMEWORK_OP_DEF";
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";
static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
value.clear_value();
value.mutable_list()->add_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def);

value.clear_value();
domi::tensorflow::NameAttrList name_attr_list;
name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
*(value.mutable_list()->add_func()) = name_attr_list;

node_def->mutable_attr()->insert({ge::ATTR_NAME_INPUT_TENSOR_DESC, value});
node_def->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value});
Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(STestTensorflowParser, tensorflow_reshape_parser_test) TEST_F(STestTensorflowParser, tensorflow_reshape_parser_test)
{ {
TensorFlowCustomParserAdapter parser; TensorFlowCustomParserAdapter parser;
@@ -2967,16 +3025,10 @@ TEST_F(STestTensorflowParser, tensorflow_ParserNodeDef2_test)
TEST_F(STestTensorflowParser, tensorflow_AddExternalGraph_test) TEST_F(STestTensorflowParser, tensorflow_AddExternalGraph_test)
{ {
TensorFlowModelParser modelParser; TensorFlowModelParser modelParser;
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");
caseDir = caseDir.substr(0, idx);
std::string modelFile = caseDir + "/origin_models/tf_add.pb";
ge::Graph graph;
std::map<ge::AscendString, ge::AscendString> parser_params = {
{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}};
auto ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
ret = modelParser.AddExternalGraph(compute_graph);
ge::ComputeGraphPtr subGraph = std::make_shared<ge::ComputeGraph>("default");
std::string inputNodeType = "DATA";
MakeDagGraph(subGraph, inputNodeType);
Status ret = modelParser.AddExternalGraph(subGraph);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }


@@ -3023,4 +3075,25 @@ TEST_F(STestTensorflowParser, tensorflow_ParseOpParams_test)
delete node_def; delete node_def;
} }


TEST_F(STestTensorflowParser, tensorflow_AddFusionInnerNodeDef_test)
{
TensorFlowModelParser model_parser;
ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
tensorflow::GraphDef *graphDef = new (std::nothrow) tensorflow::GraphDef();
ScopePassManager pass_manager;
std::shared_ptr<ScopeGraph> scope_graph = pass_manager.BuildScopeGraph(graphDef);
std::vector<std::string> op_node_name_list = {"Const", "placeholder0"};
FusionScopesResult *fusion_scope_rlt = new (std::nothrow) FusionScopesResult();
fusion_scope_rlt->Init();
fusion_scope_rlt->SetName("FusionCustom");
auto &impl_scope_graph = scope_graph->impl_;
std::string scope_name = fusion_scope_rlt->Name();
impl_scope_graph->fusion_results_.insert(std::make_pair(scope_name, fusion_scope_rlt));
std::string fusion_op_name = "FusionCustom";
GenOriginNodeDef(&model_parser, op_node_name_list);
GenFusionScopesResult(scope_graph, fusion_scope_rlt, fusion_op_name);
Status ret = model_parser.AddFusionInnerNodeDef(scope_graph, fusion_op_name, op_node_name_list);
delete graphDef;
}

} // namespace ge } // namespace ge

Loading…
Cancel
Save