/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/branch_logical_remove_pass.h"
#include "graph/utils/node_utils.h"

namespace ge {
// Entry point of the pass. Computes a branch execution condition for the nodes of
// `graph` (via BranchExecCondCalculator) and then, for every direct node:
//   - Switch / RefSwitch: rewires or drops its data output edges (RemoveRedundantSwitch);
//   - Merge / RefMerge:   drops data input edges coming from nodes whose condition is not
//                         valid (RemoveDeadInputForMerge).
//
// @param graph  graph rewritten in place; a null graph is rejected by GE_CHECK_NOTNULL.
// @return SUCCESS when there is no branch condition in the graph or every rewrite
//         succeeded; FAILED when the condition calculation or any rewrite fails.
Status BranchLogicalRemovePass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  BranchExecCondCalculator calculator(graph);
  if (calculator.Calculate() != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Calculate branch exec cond for removing logical dead branch failed.");
    return FAILED;
  }

  const auto &node_exec_cond = calculator.GetBranchExecCond();
  if (node_exec_cond.empty()) {
    GELOGI("No branch in graph %s, skip.", graph->GetName().c_str());
    return SUCCESS;
  }

  for (const auto &node : graph->GetDirectNode()) {
    const std::string &type = NodeUtils::GetNodeType(node);
    if ((type == SWITCH) || (type == REFSWITCH)) {
      if (RemoveRedundantSwitch(node_exec_cond, node) != SUCCESS) {
        GELOGE(FAILED, "Remove redundant switch %s failed.", node->GetName().c_str());
        return FAILED;
      }
    } else if ((type == MERGE) || (type == REFMERGE)) {
      if (RemoveDeadInputForMerge(node_exec_cond, node) != SUCCESS) {
        GELOGE(FAILED, "Remove dead input for merge %s failed.", node->GetName().c_str());
        return FAILED;
      }
    }
  }
  return SUCCESS;
}

// Processes the data outputs of one Switch / RefSwitch node. Does nothing when the
// switch's own exec cond is not valid. Otherwise, for every data consumer of the switch:
//   - consumer cond not valid (presumably: the consumer can never run -- confirm against
//     BranchExecCond::IsValid): the data edge from the switch is removed;
//   - consumer cond string equal to the switch's cond string: the edge is moved so the
//     consumer reads directly from the producer feeding the switch's SWITCH_DATA_INPUT,
//     bypassing the switch.
//
// @param node_exec_cond  exec cond per node; the switch and each of its data consumers
//                        must have an entry.
// @param switch_node     the Switch / RefSwitch node to process.
// @return SUCCESS on success; FAILED when a cond entry is missing or an edge operation
//         fails. A null anchor on the switch's data input is rejected by GE_CHECK_NOTNULL.
Status BranchLogicalRemovePass::RemoveRedundantSwitch(const std::map &node_exec_cond,
                                                      const NodePtr &switch_node) {
  const auto iter = node_exec_cond.find(switch_node);
  if (iter == node_exec_cond.end()) {
    GELOGE(FAILED, "Find for exec cond for node %s.", switch_node->GetName().c_str());
    return FAILED;
  }
  if (!iter->second.IsValid()) {
    return SUCCESS;
  }

  for (const auto &pair : switch_node->GetOutDataNodesAndAnchors()) {
    const auto &out_node = pair.first;
    const auto &out_in_anchor = pair.second;
    const auto out_iter = node_exec_cond.find(out_node);
    if (out_iter == node_exec_cond.end()) {
      GELOGE(FAILED, "Find for exec cond for node %s.", out_node->GetName().c_str());
      return FAILED;
    }

    // Capture the switch-side anchor and its index BEFORE touching the edge: once the
    // edge is removed, out_in_anchor->GetPeerOutAnchor() no longer returns it, so it
    // must not be dereferenced for logging afterwards.
    const auto switch_out_anchor = out_in_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(switch_out_anchor);
    const int switch_out_idx = switch_out_anchor->GetIdx();

    if (!out_iter->second.IsValid()) {
      GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(), switch_out_idx,
             out_node->GetName().c_str(), out_in_anchor->GetIdx());
      if (GraphUtils::RemoveEdge(switch_out_anchor, out_in_anchor) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(), switch_out_idx,
               out_node->GetName().c_str(), out_in_anchor->GetIdx());
        return FAILED;
      }
    } else if (iter->second.String() == out_iter->second.String()) {
      // Resolve the new source (producer of the switch's data input) before removing
      // anything, so a missing anchor does not leave the graph half-rewired.
      const auto data_input_anchor = switch_node->GetInDataAnchor(SWITCH_DATA_INPUT);
      GE_CHECK_NOTNULL(data_input_anchor);
      const auto peer_out_anchor = data_input_anchor->GetPeerOutAnchor();
      GE_CHECK_NOTNULL(peer_out_anchor);

      GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(), switch_out_idx,
             out_node->GetName().c_str(), out_in_anchor->GetIdx());
      if (GraphUtils::RemoveEdge(switch_out_anchor, out_in_anchor) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(), switch_out_idx,
               out_node->GetName().c_str(), out_in_anchor->GetIdx());
        return FAILED;
      }

      // The new edge starts at the producer's output, so log its index.
      GELOGI("Add data edge %s:%d->%s:%d.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
             peer_out_anchor->GetIdx(), out_node->GetName().c_str(), out_in_anchor->GetIdx());
      if (GraphUtils::AddEdge(peer_out_anchor, out_in_anchor) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
               peer_out_anchor->GetIdx(), out_node->GetName().c_str(), out_in_anchor->GetIdx());
        return FAILED;
      }
    }
  }
  return SUCCESS;
}

// Removes data input edges of one Merge / RefMerge node whose producer's exec cond is
// not valid. Does nothing when the merge's own exec cond is not valid. Unconnected input
// anchors are skipped.
//
// @param node_exec_cond  exec cond per node; the merge and each connected producer must
//                        have an entry.
// @param merge_node      the Merge / RefMerge node to process.
// @return SUCCESS on success; FAILED when a cond entry is missing or an edge removal fails.
Status BranchLogicalRemovePass::RemoveDeadInputForMerge(const std::map &node_exec_cond,
                                                        const NodePtr &merge_node) {
  const auto iter = node_exec_cond.find(merge_node);
  if (iter == node_exec_cond.end()) {
    GELOGE(FAILED, "Find for exec cond for node %s.", merge_node->GetName().c_str());
    return FAILED;
  }
  if (!iter->second.IsValid()) {
    return SUCCESS;
  }

  for (const auto &in_data_anchor : merge_node->GetAllInDataAnchors()) {
    const auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
    if (peer_out_anchor == nullptr) {
      continue;
    }
    const auto in_node = peer_out_anchor->GetOwnerNode();
    const auto in_iter = node_exec_cond.find(in_node);
    if (in_iter == node_exec_cond.end()) {
      GELOGE(FAILED, "Find for exec cond for node %s.", in_node->GetName().c_str());
      return FAILED;
    }
    if (!in_iter->second.IsValid()) {
      GELOGI("Remove data edge %s:%d->%s:%d.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(),
             merge_node->GetName().c_str(), in_data_anchor->GetIdx());
      if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", in_node->GetName().c_str(),
               peer_out_anchor->GetIdx(), merge_node->GetName().c_str(), in_data_anchor->GetIdx());
        return FAILED;
      }
    }
  }
  return SUCCESS;
}
}  // namespace ge