
!535 sync ge_dev to master 20220428

Merge pull request !535 from yangyongqiang/ge_dev
pull/519/MERGE
计晨 Gitee 3 years ago
commit 49bd601fde
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 19 additions and 197 deletions
  1. OWNERS  +3 -1
  2. metadef  +1 -1
  3. parser/onnx/onnx_parser.cc  +13 -7
  4. parser/tensorflow/tensorflow_parser.cc  +1 -83
  5. parser/tensorflow/tensorflow_parser.h  +1 -22
  6. tests/st/testcase/test_tensorflow_parser.cc  +0 -42
  7. tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc  +0 -41

OWNERS  +3 -1

@@ -8,4 +8,6 @@ approvers:
reviewers:
- xchu42
- sheng-nan
- - tangqunzhang
+ - tangqunzhang
+ - wangxiaotian22
+ - stevenaw

metadef  +1 -1

@@ -1 +1 @@
- Subproject commit b3374c154d01a34e7173cd982c8eb46158f790aa
+ Subproject commit 7f1f5c49e3802219a1d6c4b874b0b553a7370220

parser/onnx/onnx_parser.cc  +13 -7

@@ -594,6 +594,7 @@ Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::
}

Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {
+ bool has_error = false;
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
std::string node_name = node_proto->name();
@@ -605,7 +606,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
if (status != SUCCESS) {
GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str());
REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str());
- return status;
+ has_error = true;
+ continue;
}
node_proto->set_op_type(ori_type);

@@ -616,7 +618,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
if (status != SUCCESS) {
GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str());
REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str());
- return status;
+ has_error = true;
+ continue;
}

// 7. op parser
@@ -627,7 +630,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
status = ParseOpParam(node_proto, op, op_parser);
if (status != SUCCESS) {
GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status);
- return status;
+ has_error = true;
+ continue;
}

GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu",
@@ -638,7 +642,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str());
REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str());
- return FAILED;
+ has_error = true;
+ continue;
}
name_operator_[ParserUtils::GetOperatorName(op)] = op;

@@ -647,11 +652,12 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
if (status != SUCCESS) {
REPORT_INNER_ERROR("E19999", "ConstructInputOutputContext failed.");
GELOGE(status, "[Construct][RelationMap] to input and output failed.");
- return status;
+ has_error = true;
+ continue;
}
}
- GELOGI("Parse all node proto success.");
- return SUCCESS;
+ GELOGI("Parse all node proto end.");
+ return has_error ? FAILED : SUCCESS;
}

Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) {
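
The hunks above switch ParseAllNodeProto from fail-fast to accumulate-and-continue: a node that fails is logged, skipped, and remembered in has_error, and the function reports FAILED only after every node has been visited. Below is a minimal standalone sketch of that pattern; ParseOneNode and ParseAllNodes are hypothetical stand-ins for the AdaptOpType / TransNodeToOperator / ParseOpParam chain shown in the diff, not the real GE API.

#include <cstdio>
#include <string>
#include <vector>

enum Status { SUCCESS = 0, FAILED = 1 };

// Hypothetical per-node step; in the parser this is the sequence of
// adapt-type, trans-node and parse-params calls from the hunks above.
static Status ParseOneNode(const std::string &node) {
  return node.empty() ? FAILED : SUCCESS;  // an empty name simulates a bad node
}

// Accumulate-and-continue: every node is attempted, failures are only
// recorded, and the overall result is decided once the loop has finished.
static Status ParseAllNodes(const std::vector<std::string> &nodes) {
  bool has_error = false;
  for (const auto &node : nodes) {
    if (ParseOneNode(node) != SUCCESS) {
      std::fprintf(stderr, "parse failed for node '%s', continuing\n", node.c_str());
      has_error = true;
      continue;  // later nodes are still parsed, so all errors surface in one run
    }
  }
  return has_error ? FAILED : SUCCESS;
}

int main() {
  const std::vector<std::string> graph = {"Conv", "", "Relu"};
  return ParseAllNodes(graph) == SUCCESS ? 0 : 1;
}

The trade-off is that every node-level error shows up in a single pass of the log, at the cost of doing work that is discarded once the graph is ultimately rejected.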


parser/tensorflow/tensorflow_parser.cc  +1 -83

@@ -2470,82 +2470,6 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro
return SUCCESS;
}

- // For the identity operator whose output is "_retval", optimize it.
- Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map,
- const string &curr_node_name, bool &clear_input_flag) {
- auto context_iter = op_node_context_map_.find(curr_node_name);
- if (context_iter == op_node_context_map_.end()) {
- REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str());
- GELOGE(FAILED, "Can't find op node context.");
- return INTERNAL_ERROR;
- }
- OpNodeContext op_node_context = context_iter->second;
-
- const std::map<std::string, NodeDef *>::const_iterator node_def_iter = nodedef_map.find(curr_node_name);
- if (node_def_iter == nodedef_map.cend()) {
- REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map, check invalid", curr_node_name.c_str());
- GELOGE(FAILED, "Can't find nodedef");
- return INTERNAL_ERROR;
- }
- domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second;
- GE_CHECK_NOTNULL(curr_node_def);
- bool has_out_retval = false;
- // For the identity operator whose output is "_retval", optimize it
- std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map;
- for (auto output_iter = output_map.cbegin(); output_iter != output_map.cend(); ++output_iter) {
- const string &output_node_name = output_iter->first;
- domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
- GE_CHECK_NOTNULL(output_node_def);
- if (output_node_def->op() == "_Retval") {
- GELOGD("_Retval Identity need optimize.");
- output_node_def->set_input(0, curr_node_def->input(0).c_str());
- has_out_retval = true;
- GELOGD("op %s set input(0):%s.", output_node_def->name().c_str(), curr_node_def->input(0).c_str());
- }
- }
-
- // Deal with non _Retval output operator of Identity.
- if (has_out_retval) {
- std::map<std::string, std::vector<std::pair<int32_t, int32_t>>>::const_iterator output_iter = output_map.begin();
- for (; output_iter != output_map.end(); ++output_iter) {
- const string &output_node_name = output_iter->first;
- domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
- GE_CHECK_NOTNULL(output_node_def);
- GE_IF_BOOL_EXEC(output_node_def->op() == "_Retval", continue);
- for (int k = 0; k < output_node_def->input_size(); ++k) {
- GE_IF_BOOL_EXEC(
- output_node_def->input(k) == curr_node_name, output_node_def->set_input(k, curr_node_def->input(0).c_str());
- GELOGD("%s op set input(%d):%s.", output_node_def->name().c_str(), k, curr_node_def->input(0).c_str());)
- }
- }
- clear_input_flag = true;
- }
- return SUCCESS;
- }
-
- Status TensorFlowModelParser::GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def,
- map<string, NodeDef *> &nodedef_map,
- const vector<NodeDef *> &nodedef_to_optimize) {
- GE_CHECK_NOTNULL(graph_def);
- if (!nodedef_to_optimize.empty()) {
- // Building input and input relationships for all OP nodes
- GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def));
- } else {
- return SUCCESS;
- }
- for (auto &curr_node_def : nodedef_to_optimize) {
- GE_CHECK_NOTNULL(curr_node_def);
- bool clear_input_flag = false;
- const string &curr_node_name = curr_node_def->name();
- GE_RETURN_IF_ERROR(OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag));
- if (clear_input_flag) {
- curr_node_def->clear_input();
- }
- }
- GELOGI("GraphDefOptimizeIdentity success.");
- return SUCCESS;
- }

Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def,
map<string, NodeDef *> &nodedef_map,
const std::pair<string, int> &input_data,
@@ -2861,8 +2785,6 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph
GE_CHECK_NOTNULL(graph_def);
map<string, NodeDef *> nodedef_map;
vector<string> op_node_name_list;
- // Save Identity and ReadVariableOp
- vector<NodeDef *> identity_to_optimize;
// Save Snapshot
vector<NodeDef *> snapshot_to_optimize;

@@ -2872,16 +2794,12 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph
const string &node_name = node_def->name();
Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
- if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) {
- identity_to_optimize.push_back(node_def);
- } else if (node_def->op() == ge::parser::SNAPSHOT) {
+ if (node_def->op() == ge::parser::SNAPSHOT) {
snapshot_to_optimize.push_back(node_def);
}
nodedef_map[node_name] = node_def;
}

- // Optimize for Identity/ReadVariableOp
- GE_RETURN_IF_ERROR(GraphDefOptimizeIdentity(graph_def, nodedef_map, identity_to_optimize));
// Optimize for Snapshot
GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize));
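
With GraphDefOptimizeIdentity and OptimizeIdentityByOutput deleted, the classification loop in GraphDefOptimize only collects Snapshot nodes; Identity and ReadVariableOp nodes now flow through untouched. Below is a condensed, self-contained sketch of the remaining control flow, using a simplified NodeDef struct, plain std containers, and the literal op string "Snapshot" as stand-ins for the protobuf types, GE macros, and ge::parser constants.

#include <map>
#include <string>
#include <vector>

// Simplified stand-in for domi::tensorflow::NodeDef.
struct NodeDef {
  std::string name;
  std::string op;
};

enum Status { SUCCESS = 0, FAILED = 1 };

// Shape of GraphDefOptimize after this change: only Snapshot nodes are
// gathered for a dedicated pass; the Identity/ReadVariableOp bucket is gone.
Status GraphDefOptimizeSketch(std::vector<NodeDef> &graph) {
  std::map<std::string, NodeDef *> nodedef_map;
  std::vector<NodeDef *> snapshot_to_optimize;
  for (auto &node : graph) {
    if (node.op == "Snapshot") {
      snapshot_to_optimize.push_back(&node);
    }
    nodedef_map[node.name] = &node;
  }
  // GraphDefOptimizeSnapShot(...) would consume nodedef_map and
  // snapshot_to_optimize here; Identity nodes are left as they are.
  return SUCCESS;
}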



parser/tensorflow/tensorflow_parser.h  +1 -22

@@ -415,28 +415,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
* @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef
*/
Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def);
- /**
- * @ingroup domi_omg
- * @brief Optimize for Identity/ReadVariableOp operator
- * @param [in] graph_def GraphDef to be optimized
- * @param [in] nodedef_map Map of all nodes in graph
- * @param [in] nodedef_to_optimize vector of NodeDef to be optimized
- * @return SUCCESS optimize successfully
- * @return others failed
- */
- Status GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map,
- const vector<NodeDef *> &nodedef_to_optimize);
- /**
- * @ingroup domi_omg
- * @brief For the identity operator whose output is "_retval", optimize it.
- * @param [in] nodedef_map Map of all nodes in graph
- * @param [in] curr_node_name Name of node to be optimized
- * @param [in] clear_input_flag Flag of whether to clear the input of the current node
- * @return SUCCESS optimize successfully
- * @return others failed
- */
- Status OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, const string &curr_node_name,
- bool &clear_input_flag);

Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map,
const vector<NodeDef *> &nodedef_to_optimize);
Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def,


tests/st/testcase/test_tensorflow_parser.cc  +0 -42

@@ -2649,29 +2649,6 @@ TEST_F(STestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test)
model_parser.UpdateEdgesControlInfo(info);
}

- TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test)
- {
- TensorFlowModelParser model_parser;
- NodeDef *node_def = new NodeDef();
- node_def->set_name("Placeholder");
- node_def->set_op("Placeholder_0");
- std::map<string, NodeDef *> nodedef_map;
- nodedef_map.emplace("Placeholder", node_def);
- std::string curr_node_name = "Placeholder";
- bool clear_input_flag = true;
- Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
- EXPECT_EQ(ret, INTERNAL_ERROR);
-
- GraphDef graph;
- curr_node_name = "pre_node_a";
- nodedef_map.emplace("pre_node_a", node_def);
- node_def->set_op("pre_node_a");
- GenOriginContext(&model_parser, curr_node_name);
- ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
- EXPECT_EQ(ret, SUCCESS);
- delete node_def;
- }

TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test)
{
TensorFlowModelParser model_parser;
@@ -2843,25 +2820,6 @@ TEST_F(STestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test)
EXPECT_EQ(ret, SUCCESS);
}

- TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test)
- {
- tensorflow::GraphDef graph_def;
- TensorFlowModelParser tensorflow_parser;
- tensorflow::NodeDef *node_def = initNodeDef();
- node_def->set_name("post_node_d");
-
- std::map<string, NodeDef *> nodedef_map;
- nodedef_map.emplace("post_node_d", node_def);
- nodedef_map.emplace("post_node_a", node_def);
- nodedef_map.emplace("post_node_b", node_def);
- std::vector<NodeDef *> nodedef_to_optimize;
- nodedef_to_optimize.emplace_back(node_def);
-
- std::string curr_node_name = "post_node_b";
- GenOriginContext(&tensorflow_parser, curr_node_name);
- Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize);
- EXPECT_EQ(ret, ge::PARAM_INVALID);
- }
TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) {
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");


tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc  +0 -41

@@ -2825,29 +2825,6 @@ TEST_F(UtestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test)
model_parser.UpdateEdgesControlInfo(info);
}

- TEST_F(UtestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test)
- {
- TensorFlowModelParser model_parser;
- NodeDef *node_def = new NodeDef();
- node_def->set_name("Placeholder");
- node_def->set_op("Placeholder_0");
- std::map<string, NodeDef *> nodedef_map;
- nodedef_map.emplace("Placeholder", node_def);
- std::string curr_node_name = "Placeholder";
- bool clear_input_flag = true;
- Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
- EXPECT_EQ(ret, INTERNAL_ERROR);
-
- GraphDef graph;
- curr_node_name = "pre_node_a";
- nodedef_map.emplace("pre_node_a", node_def);
- node_def->set_op("pre_node_a");
- GenOriginContext(&model_parser, curr_node_name);
- ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
- EXPECT_EQ(ret, SUCCESS);
- delete node_def;
- }

TEST_F(UtestTensorflowParser, tensorflow_OptimizeSnapShot_test)
{
TensorFlowModelParser model_parser;
@@ -3019,25 +2996,7 @@ TEST_F(UtestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test)
EXPECT_EQ(ret, SUCCESS);
}

- TEST_F(UtestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test)
- {
- tensorflow::GraphDef graph_def;
- TensorFlowModelParser tensorflow_parser;
- tensorflow::NodeDef *node_def = initNodeDef();
- node_def->set_name("post_node_d");
-
- std::map<string, NodeDef *> nodedef_map;
- nodedef_map.emplace("post_node_d", node_def);
- nodedef_map.emplace("post_node_a", node_def);
- nodedef_map.emplace("post_node_b", node_def);
- std::vector<NodeDef *> nodedef_to_optimize;
- nodedef_to_optimize.emplace_back(node_def);
-
- std::string curr_node_name = "post_node_b";
- GenOriginContext(&tensorflow_parser, curr_node_name);
- Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize);
- EXPECT_EQ(ret, ge::PARAM_INVALID);
- }
TEST_F(UtestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) {
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");

