|
|
@@ -43,7 +43,7 @@ const char *const kAttrNameDstFormat = "dst_format"; |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
void TransOpWithoutReshapeFusionPass::SetRemainNode( |
|
|
|
const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) { |
|
|
|
const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) { |
|
|
|
auto iter = nodes_anchor.begin(); |
|
|
|
while (iter != nodes_anchor.end()) { |
|
|
|
auto in_anchor = iter->second; |
|
|
@@ -63,7 +63,8 @@ void TransOpWithoutReshapeFusionPass::SetRemainNode( |
|
|
|
if (op_desc == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return); |
|
|
|
GELOGI("SetRemainNode node is %s", op_desc->GetName().c_str()); |
|
|
|
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return ); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -158,7 +159,7 @@ graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() { |
|
|
|
} |
|
|
|
|
|
|
|
void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors( |
|
|
|
const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) { |
|
|
|
const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) { |
|
|
|
// The caller guarantees that the index is legal. |
|
|
|
for (size_t j = 1; j < sub_graph_anchors_[index].size(); ++j) { |
|
|
|
auto nodes_anchor = sub_graph_anchors_[index][j]; |
|
|
@@ -181,9 +182,9 @@ void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors( |
|
|
|
} |
|
|
|
|
|
|
|
void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors( |
|
|
|
const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) { |
|
|
|
const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) { |
|
|
|
// The caller guarantees that the index is legal. |
|
|
|
for (size_t j = 1; j < sub_graph_nodes_[index].size(); ++j) { |
|
|
|
for (size_t j = 1; j < (sub_graph_nodes_[index].size() - 1); ++j) { |
|
|
|
auto node = sub_graph_nodes_[index][j]; |
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(node); |
|
|
|
auto in_control_anchor = node->GetInControlAnchor(); |
|
|
@@ -208,8 +209,8 @@ void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors( |
|
|
|
} |
|
|
|
|
|
|
|
void TransOpWithoutReshapeFusionPass::GetOutControlPeerAnchors( |
|
|
|
const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors, |
|
|
|
vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) { |
|
|
|
const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors, |
|
|
|
vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) { |
|
|
|
for (size_t j = 0; j < sub_graph_nodes_[index].size() - 1; ++j) { |
|
|
|
auto node = sub_graph_nodes_[index][j]; |
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(node); |
|
|
@@ -335,8 +336,8 @@ void TransOpWithoutReshapeFusionPass::UpdateInputName(const OutDataAnchorPtr &ol |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges( |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair, |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) { |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair, |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) { |
|
|
|
auto out_anchor = begin_anchors_pair.first; |
|
|
|
GE_CHECK_NOTNULL(out_anchor); |
|
|
|
auto out_owner_node = out_anchor->GetOwnerNode(); |
|
|
@@ -364,8 +365,8 @@ graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges( |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChanged( |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair, |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) { |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair, |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) { |
|
|
|
if (RelinkSubGraphControlEdges(begin_anchors_pair, end_anchors_pair, index) != GRAPH_SUCCESS) { |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
@@ -418,8 +419,8 @@ graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChange |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus TransOpWithoutReshapeFusionPass::RelinkNodesWhenDescNotChanged( |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair, |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) { |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair, |
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) { |
|
|
|
auto out_anchor = begin_anchors_pair.first; |
|
|
|
GE_CHECK_NOTNULL(out_anchor); |
|
|
|
auto out_owner_node = out_anchor->GetOwnerNode(); |
|
|
@@ -581,7 +582,7 @@ void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int inde |
|
|
|
auto out_owner_node = out_peer_anchor->GetOwnerNode(); |
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node); |
|
|
|
auto out_peer_op_desc = out_owner_node->GetOpDesc(); |
|
|
|
GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return); |
|
|
|
GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return ); |
|
|
|
out_desc = out_peer_op_desc->GetInputDesc(out_peer_anchor->GetIdx()); |
|
|
|
|
|
|
|
auto in_peer_anchor = nodes_anchor.back().first; |
|
|
@@ -589,7 +590,7 @@ void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int inde |
|
|
|
auto in_owner_node = in_peer_anchor->GetOwnerNode(); |
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node); |
|
|
|
auto in_peer_op_desc = in_owner_node->GetOpDesc(); |
|
|
|
GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return); |
|
|
|
GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return ); |
|
|
|
in_desc = in_peer_op_desc->GetOutputDesc(in_peer_anchor->GetIdx()); |
|
|
|
} |
|
|
|
|
|
|
@@ -721,7 +722,7 @@ void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &g |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return); |
|
|
|
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return ); |
|
|
|
GELOGI("remove node:%s", node->GetName().c_str()); |
|
|
|
if (graph->RemoveNode(node) != GRAPH_SUCCESS) { |
|
|
|
GELOGW("remove node failed!node:%s", node->GetName().c_str()); |
|
|
@@ -743,7 +744,7 @@ graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) { |
|
|
|
if (IsTransOp(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
GELOGD("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str()); |
|
|
|
GELOGI("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str()); |
|
|
|
for (const auto &out_anchor : node->GetAllOutDataAnchors()) { |
|
|
|
GE_CHECK_NOTNULL(out_anchor); |
|
|
|
vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors; |
|
|
@@ -887,11 +888,6 @@ graphStatus TransOpWithoutReshapeFusionPass::GetTransNode(const ComputeGraphPtr |
|
|
|
new_trans_nodes.push_back(cast_node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (new_trans_nodes.empty()) { |
|
|
|
GELOGE(GRAPH_FAILED, "no new transop!this should not happen!"); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
@@ -902,6 +898,10 @@ graphStatus TransOpWithoutReshapeFusionPass::InsertNewTransOp(const ComputeGraph |
|
|
|
if (GetTransNode(graph, cast_op, format_transfer_op, insert_cast_first, new_trans_nodes) != GRAPH_SUCCESS) { |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
if (new_trans_nodes.empty()) { |
|
|
|
GELOGI("No new trans node. Do not need insert new transop."); |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
pair<OutDataAnchorPtr, InDataAnchorPtr> begin_out = sub_graph_anchors_[index].front(); |
|
|
|
pair<OutDataAnchorPtr, InDataAnchorPtr> end_in = sub_graph_anchors_[index].back(); |
|
|
@@ -1051,9 +1051,8 @@ bool TransOpWithoutReshapeFusionPass::FusionFormatSupport(Format format) { |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphsBetweenNormalNode( |
|
|
|
const OutDataAnchorPtr &out_anchor, |
|
|
|
std::vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out, |
|
|
|
vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) { |
|
|
|
const OutDataAnchorPtr &out_anchor, std::vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out, |
|
|
|
vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) { |
|
|
|
graphStatus ret = GRAPH_SUCCESS; |
|
|
|
if (out_anchor == nullptr) { |
|
|
|
return GRAPH_FAILED; |
|
|
|