|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "graph/passes/branch_logical_remove_pass.h"
- #include "graph/utils/node_utils.h"
-
- namespace ge {
- Status BranchLogicalRemovePass::Run(ComputeGraphPtr graph) {
- BranchExecCondCalculator calculator(graph);
- if (calculator.Calculate() != GRAPH_SUCCESS) {
- GELOGE(FAILED, "Calculate branch exec cond for removing logical dead branch failed.");
- return FAILED;
- }
- const auto &node_exec_cond = calculator.GetBranchExecCond();
- if (node_exec_cond.empty()) {
- GELOGI("No branch in graph %s, skip.", graph->GetName().c_str());
- return SUCCESS;
- }
-
- for (const auto &node : graph->GetDirectNode()) {
- const std::string &type = NodeUtils::GetNodeType(node);
- if ((type == SWITCH) || (type == REFSWITCH)) {
- if (RemoveRedundantSwitch(node_exec_cond, node) != SUCCESS) {
- GELOGE(FAILED, "Remove redundant switch %s failed.", node->GetName().c_str());
- return FAILED;
- }
- } else if ((type == MERGE) || (type == REFMERGE)) {
- if (RemoveDeadInputForMerge(node_exec_cond, node) != SUCCESS) {
- GELOGE(FAILED, "Remove dead input for merge %s failed.", node->GetName().c_str());
- return FAILED;
- }
- }
- }
-
- return SUCCESS;
- }
-
- Status BranchLogicalRemovePass::RemoveRedundantSwitch(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond,
- const NodePtr &switch_node) {
- const auto &iter = node_exec_cond.find(switch_node);
- if (iter == node_exec_cond.end()) {
- GELOGE(FAILED, "Find for exec cond for node %s.", switch_node->GetName().c_str());
- return FAILED;
- }
- if (!iter->second.IsValid()) {
- return SUCCESS;
- }
- for (const auto &pair : switch_node->GetOutDataNodesAndAnchors()) {
- const auto &out_node = pair.first;
- const auto &out_iter = node_exec_cond.find(out_node);
- if (out_iter == node_exec_cond.end()) {
- GELOGE(FAILED, "Find for exec cond for node %s.", out_node->GetName().c_str());
- return FAILED;
- }
- if (!out_iter->second.IsValid()) {
- GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(),
- pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
- if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(),
- pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
- return FAILED;
- }
- } else if (iter->second.String() == out_iter->second.String()) {
- GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(),
- pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
- if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(),
- pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
- return FAILED;
- }
-
- const auto data_input_anchor = switch_node->GetInDataAnchor(SWITCH_DATA_INPUT);
- GE_CHECK_NOTNULL(data_input_anchor);
- const auto &peer_out_anchor = data_input_anchor->GetPeerOutAnchor();
- GE_CHECK_NOTNULL(peer_out_anchor);
- GELOGI("Add data edge %s:%d->%s:%d.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
- pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
- if (GraphUtils::AddEdge(peer_out_anchor, pair.second) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
- pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
- return FAILED;
- }
- }
- }
-
- return SUCCESS;
- }
-
- Status BranchLogicalRemovePass::RemoveDeadInputForMerge(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond,
- const NodePtr &merge_node) {
- const auto &iter = node_exec_cond.find(merge_node);
- if (iter == node_exec_cond.end()) {
- GELOGE(FAILED, "Find for exec cond for node %s.", merge_node->GetName().c_str());
- return FAILED;
- }
- if (!iter->second.IsValid()) {
- return SUCCESS;
- }
- for (const auto &in_data_anchor : merge_node->GetAllInDataAnchors()) {
- const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
- if (peer_out_anchor == nullptr) { continue; }
- const auto &in_node = peer_out_anchor->GetOwnerNode();
- const auto &in_iter = node_exec_cond.find(in_node);
- if (in_iter == node_exec_cond.end()) {
- GELOGE(FAILED, "Find for exec cond for node %s.", in_node->GetName().c_str());
- return FAILED;
- }
- if (!in_iter->second.IsValid()) {
- GELOGI("Remove data edge %s:%d->%s:%d.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(),
- merge_node->GetName().c_str(), in_data_anchor->GetIdx());
- if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) {
- GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(),
- merge_node->GetName().c_str(), in_data_anchor->GetIdx());
- return FAILED;
- }
- }
- }
- return SUCCESS;
- }
- } // namespace ge
|