Browse Source

!168 Bugfix: fix transpose fusion with input&output format check

Merge pull request !168 from zhaoxinxin/development
tags/v1.1.0
王涛 Gitee 4 years ago
parent
commit
4369336996
2 changed files with 17 additions and 0 deletions
  1. +8
    -0
      ge/graph/passes/transop_without_reshape_fusion_pass.cc
  2. +9
    -0
      ge/graph/passes/transpose_transdata_pass.cc

+ 8
- 0
ge/graph/passes/transop_without_reshape_fusion_pass.cc View File

@@ -130,6 +130,14 @@ graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() {
sub_graph_has_reshape_node[i] = true;
break;
}
if (in_node->GetType() == TRANSPOSE || in_node->GetType() == TRANSPOSED) {
auto input_format = in_node->GetOpDesc()->GetInputDescPtr(0)->GetFormat();
auto output_format = in_node->GetOpDesc()->GetOutputDescPtr(0)->GetFormat();
if (input_format == output_format) {
sub_graph_has_reshape_node[i] = true;
break;
}
}

auto out_anchor = iter->first;
GE_CHECK_NOTNULL(out_anchor);


+ 9
- 0
ge/graph/passes/transpose_transdata_pass.cc View File

@@ -46,6 +46,15 @@ Status TransposeTransDataPass::Run(NodePtr &node) {
if (op_desc->GetType() != TRANSPOSED) {
return SUCCESS;
}
auto input_format = op_desc->GetInputDescPtr(0)->GetFormat();
auto output_format = op_desc->GetOutputDescPtr(0)->GetFormat();
if (input_format == output_format) {
GELOGW("Node %s input format is %s, output format is %s, should not happend. Ignore pass.",
op_desc->GetName().c_str(),
TypeUtils::FormatToSerialString(input_format).c_str(),
TypeUtils::FormatToSerialString(output_format).c_str());
return SUCCESS;
}
if (CheckOneInAndOneOutDataAnchor(node) != SUCCESS) {
return FAILED;
}


Loading…
Cancel
Save