From: @hugo1 Reviewed-by: @sheng-nan,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
@@ -270,7 +270,6 @@ set(TRAIN_SRC_LIST | |||||
"graph/passes/identity_pass.cc" | "graph/passes/identity_pass.cc" | ||||
"graph/passes/ref_identity_delete_op_pass.cc" | "graph/passes/ref_identity_delete_op_pass.cc" | ||||
"graph/passes/infershape_pass.cc" | "graph/passes/infershape_pass.cc" | ||||
"graph/passes/isolated_op_remove_pass.cc" | |||||
"graph/passes/iterator_op_pass.cc" | "graph/passes/iterator_op_pass.cc" | ||||
"graph/passes/link_gen_mask_nodes_pass.cc" | "graph/passes/link_gen_mask_nodes_pass.cc" | ||||
"graph/passes/merge_pass.cc" | "graph/passes/merge_pass.cc" | ||||
@@ -317,13 +316,11 @@ set(TRAIN_SRC_LIST | |||||
"graph/passes/transop_without_reshape_fusion_pass.cc" | "graph/passes/transop_without_reshape_fusion_pass.cc" | ||||
"graph/passes/transpose_transdata_pass.cc" | "graph/passes/transpose_transdata_pass.cc" | ||||
"graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
"graph/passes/unused_op_remove_pass.cc" | |||||
"graph/passes/var_is_initialized_op_pass.cc" | "graph/passes/var_is_initialized_op_pass.cc" | ||||
"graph/passes/parallel_concat_start_op_pass.cc" | "graph/passes/parallel_concat_start_op_pass.cc" | ||||
"graph/passes/cond_pass.cc" | "graph/passes/cond_pass.cc" | ||||
"graph/passes/cond_remove_pass.cc" | "graph/passes/cond_remove_pass.cc" | ||||
"graph/passes/for_pass.cc" | "graph/passes/for_pass.cc" | ||||
"graph/passes/variable_format_pass.cc" | |||||
"graph/passes/variable_op_pass.cc" | "graph/passes/variable_op_pass.cc" | ||||
"graph/passes/variable_prepare_op_pass.cc" | "graph/passes/variable_prepare_op_pass.cc" | ||||
"graph/passes/variable_ref_delete_op_pass.cc" | "graph/passes/variable_ref_delete_op_pass.cc" | ||||
@@ -522,12 +519,10 @@ set(INFER_SRC_LIST | |||||
"graph/passes/dimension_adjust_pass.cc" | "graph/passes/dimension_adjust_pass.cc" | ||||
"graph/passes/get_original_format_pass.cc" | "graph/passes/get_original_format_pass.cc" | ||||
"graph/passes/shape_operate_op_remove_pass.cc" | "graph/passes/shape_operate_op_remove_pass.cc" | ||||
"graph/passes/unused_op_remove_pass.cc" | |||||
"graph/passes/assert_pass.cc" | "graph/passes/assert_pass.cc" | ||||
"graph/passes/dropout_pass.cc" | "graph/passes/dropout_pass.cc" | ||||
"graph/passes/infershape_pass.cc" | "graph/passes/infershape_pass.cc" | ||||
"graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
"graph/passes/isolated_op_remove_pass.cc" | |||||
"graph/passes/permute_pass.cc" | "graph/passes/permute_pass.cc" | ||||
"graph/passes/ctrl_edge_transfer_pass.cc" | "graph/passes/ctrl_edge_transfer_pass.cc" | ||||
"graph/passes/end_of_sequence_add_control_pass.cc" | "graph/passes/end_of_sequence_add_control_pass.cc" | ||||
@@ -610,7 +605,6 @@ set(INFER_SRC_LIST | |||||
"graph/passes/switch_logic_remove_pass.cc" | "graph/passes/switch_logic_remove_pass.cc" | ||||
"graph/passes/switch_data_edges_bypass.cc" | "graph/passes/switch_data_edges_bypass.cc" | ||||
"graph/passes/merge_pass.cc" | "graph/passes/merge_pass.cc" | ||||
"graph/passes/variable_format_pass.cc" | |||||
"graph/passes/variable_op_pass.cc" | "graph/passes/variable_op_pass.cc" | ||||
"graph/passes/cast_remove_pass.cc" | "graph/passes/cast_remove_pass.cc" | ||||
"graph/passes/transpose_transdata_pass.cc" | "graph/passes/transpose_transdata_pass.cc" | ||||
@@ -122,12 +122,10 @@ OMG_HOST_SRC_FILES := \ | |||||
graph/passes/dimension_adjust_pass.cc \ | graph/passes/dimension_adjust_pass.cc \ | ||||
graph/passes/get_original_format_pass.cc \ | graph/passes/get_original_format_pass.cc \ | ||||
graph/passes/shape_operate_op_remove_pass.cc \ | graph/passes/shape_operate_op_remove_pass.cc \ | ||||
graph/passes/unused_op_remove_pass.cc \ | |||||
graph/passes/assert_pass.cc \ | graph/passes/assert_pass.cc \ | ||||
graph/passes/dropout_pass.cc \ | graph/passes/dropout_pass.cc \ | ||||
graph/passes/infershape_pass.cc \ | graph/passes/infershape_pass.cc \ | ||||
graph/passes/unused_const_pass.cc \ | graph/passes/unused_const_pass.cc \ | ||||
graph/passes/isolated_op_remove_pass.cc \ | |||||
graph/passes/permute_pass.cc \ | graph/passes/permute_pass.cc \ | ||||
graph/passes/ctrl_edge_transfer_pass.cc \ | graph/passes/ctrl_edge_transfer_pass.cc \ | ||||
graph/passes/end_of_sequence_add_control_pass.cc \ | graph/passes/end_of_sequence_add_control_pass.cc \ | ||||
@@ -209,7 +207,6 @@ OMG_HOST_SRC_FILES := \ | |||||
graph/passes/switch_logic_remove_pass.cc \ | graph/passes/switch_logic_remove_pass.cc \ | ||||
graph/passes/switch_data_edges_bypass.cc \ | graph/passes/switch_data_edges_bypass.cc \ | ||||
graph/passes/merge_pass.cc \ | graph/passes/merge_pass.cc \ | ||||
graph/passes/variable_format_pass.cc \ | |||||
graph/passes/variable_op_pass.cc \ | graph/passes/variable_op_pass.cc \ | ||||
graph/passes/cast_remove_pass.cc \ | graph/passes/cast_remove_pass.cc \ | ||||
graph/passes/transpose_transdata_pass.cc \ | graph/passes/transpose_transdata_pass.cc \ | ||||
@@ -187,7 +187,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
graph/passes/identity_pass.cc \ | graph/passes/identity_pass.cc \ | ||||
graph/passes/ref_identity_delete_op_pass.cc \ | graph/passes/ref_identity_delete_op_pass.cc \ | ||||
graph/passes/infershape_pass.cc \ | graph/passes/infershape_pass.cc \ | ||||
graph/passes/isolated_op_remove_pass.cc \ | |||||
graph/passes/iterator_op_pass.cc \ | graph/passes/iterator_op_pass.cc \ | ||||
graph/passes/link_gen_mask_nodes_pass.cc \ | graph/passes/link_gen_mask_nodes_pass.cc \ | ||||
graph/passes/merge_pass.cc \ | graph/passes/merge_pass.cc \ | ||||
@@ -233,13 +232,11 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
graph/passes/transop_without_reshape_fusion_pass.cc \ | graph/passes/transop_without_reshape_fusion_pass.cc \ | ||||
graph/passes/transpose_transdata_pass.cc \ | graph/passes/transpose_transdata_pass.cc \ | ||||
graph/passes/unused_const_pass.cc \ | graph/passes/unused_const_pass.cc \ | ||||
graph/passes/unused_op_remove_pass.cc \ | |||||
graph/passes/var_is_initialized_op_pass.cc \ | graph/passes/var_is_initialized_op_pass.cc \ | ||||
graph/passes/parallel_concat_start_op_pass.cc \ | graph/passes/parallel_concat_start_op_pass.cc \ | ||||
graph/passes/cond_pass.cc \ | graph/passes/cond_pass.cc \ | ||||
graph/passes/cond_remove_pass.cc \ | graph/passes/cond_remove_pass.cc \ | ||||
graph/passes/for_pass.cc \ | graph/passes/for_pass.cc \ | ||||
graph/passes/variable_format_pass.cc \ | |||||
graph/passes/variable_op_pass.cc \ | graph/passes/variable_op_pass.cc \ | ||||
graph/passes/variable_prepare_op_pass.cc \ | graph/passes/variable_prepare_op_pass.cc \ | ||||
graph/passes/variable_ref_delete_op_pass.cc \ | graph/passes/variable_ref_delete_op_pass.cc \ | ||||
@@ -1,37 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/isolated_op_remove_pass.h" | |||||
#include "common/debug/log.h" | |||||
#include "common/types.h" | |||||
#include "common/util.h" | |||||
namespace ge { | |||||
Status IsolatedOpRemovePass::Run(ge::ComputeGraphPtr graph) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
for (NodePtr &node_ptr : graph->GetDirectNode()) { | |||||
GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, continue); | |||||
if (node_ptr->GetInDataNodes().size() == 0 && node_ptr->GetOutAllNodes().size() == 0 && | |||||
!(node_ptr->GetOpDesc()->HasAttr(TO_BE_OUTPUT))) { | |||||
GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node_ptr), "remove graph node [%s] fail", | |||||
node_ptr->GetOpDesc()->GetName().c_str()); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge |
@@ -1,28 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||||
#define GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||||
#include "inc/graph_pass.h" | |||||
namespace ge { | |||||
class IsolatedOpRemovePass : public GraphPass { | |||||
public: | |||||
Status Run(ge::ComputeGraphPtr graph); | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ |
@@ -1,47 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "remove_nodes_pass.h" | |||||
#include "debug/ge_log.h" | |||||
#include "inc/framework/common/util.h" | |||||
#include "inc/graph/utils/node_utils.h" | |||||
namespace ge { | |||||
Status RemoveNodesPass::Run(NodePtr &node) { | |||||
GE_CHECK_NOTNULL(node); | |||||
auto node_type = NodeUtils::GetNodeType(*node); | |||||
auto type_iter = remove_node_types_to_arg_.find(node_type); | |||||
if (type_iter != remove_node_types_to_arg_.end()) { | |||||
GELOGI("Remove node %s by type %s", node->GetName().c_str(), node_type.c_str()); | |||||
return IsolateAndDeleteNode(node, type_iter->second); | |||||
} | |||||
for (const auto &attr_name_to_arg : remove_node_attr_names_to_arg_) { | |||||
if (AttrUtils::HasAttr(node->GetOpDesc(), attr_name_to_arg.first)) { | |||||
GELOGI("Remove node %s by attr name %s", node->GetName().c_str(), attr_name_to_arg.first.c_str()); | |||||
return IsolateAndDeleteNode(node, attr_name_to_arg.second); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
RemoveNodesPass &RemoveNodesPass::AddNodeType(const string &node_type, std::initializer_list<int> arg) { | |||||
remove_node_types_to_arg_[node_type] = std::move(arg); | |||||
return *this; | |||||
} | |||||
RemoveNodesPass &RemoveNodesPass::AddAttrName(const string &attr_name, std::initializer_list<int> arg) { | |||||
remove_node_attr_names_to_arg_[attr_name] = std::move(arg); | |||||
return *this; | |||||
} | |||||
} // namespace ge |
@@ -1,32 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_REMOVE_NODES_PASS_H_ | |||||
#define GE_REMOVE_NODES_PASS_H_ | |||||
#include "graph/passes/base_pass.h" | |||||
namespace ge { | |||||
class RemoveNodesPass : public BaseNodePass { | |||||
public: | |||||
Status Run(NodePtr &node) override; | |||||
RemoveNodesPass &AddNodeType(const std::string &node_type, std::initializer_list<int> arg = {0}); | |||||
RemoveNodesPass &AddAttrName(const std::string &attr_name, std::initializer_list<int> arg = {0}); | |||||
private: | |||||
std::map<std::string, std::initializer_list<int>> remove_node_types_to_arg_; | |||||
std::map<std::string, std::initializer_list<int>> remove_node_attr_names_to_arg_; | |||||
}; | |||||
} // namespace ge | |||||
#endif //GE_REMOVE_NODES_PASS_H_ |
@@ -1,134 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/unused_op_remove_pass.h" | |||||
#include <queue> | |||||
#include <set> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "common/debug/log.h" | |||||
#include "common/op/ge_op_utils.h" | |||||
#include "common/types.h" | |||||
#include "common/util.h" | |||||
#include "graph/utils/attr_utils.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#include "inc/pass_manager.h" | |||||
#include "graph/passes/isolated_op_remove_pass.h" | |||||
using domi::SUCCESS; | |||||
namespace ge { | |||||
const std::set<std::string> kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT}; | |||||
const std::set<std::string> kOtherRemoveOpSet = {DROPOUT}; | |||||
Status UnusedOpRemovePass::Run(ComputeGraphPtr graph) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
std::set<std::string> remove_op_set; | |||||
vector<NodePtr> nodes_to_be_deleted; | |||||
if (fmktype_ == TENSORFLOW) { | |||||
remove_op_set = kRemoveOpSet; | |||||
} else { | |||||
remove_op_set = kOtherRemoveOpSet; | |||||
} | |||||
for (auto &node : graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
std::string op_type_str = node->GetOpDesc()->GetType(); | |||||
if (remove_op_set.count(op_type_str)) { | |||||
if (IsExceptions(node)) { | |||||
continue; | |||||
} | |||||
for (auto &out_anchor : node->GetAllOutDataAnchors()) { | |||||
for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
NodePtr dst_node = in_anchor->GetOwnerNode(); | |||||
GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||||
int dst_index = in_anchor->GetIdx(); | |||||
std::vector<bool> list_bool; | |||||
GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||||
list_bool = dst_node->GetOpDesc()->GetIsInputConst(); | |||||
GE_IF_BOOL_EXEC(list_bool.size() == 0, continue); | |||||
list_bool.erase(list_bool.begin() + dst_index); | |||||
dst_node->GetOpDesc()->SetIsInputConst(list_bool); | |||||
} | |||||
} | |||||
if (op_type_str == ASSERT) { | |||||
GE_CHK_STATUS_RET(CollectParentNode(graph, node, nodes_to_be_deleted), "remove node failed"); | |||||
} else { | |||||
GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node failed"); | |||||
} | |||||
} | |||||
} | |||||
for (auto &node : nodes_to_be_deleted) { | |||||
for (InDataAnchorPtr &inAnchor : node->GetAllInDataAnchors()) { | |||||
inAnchor->UnlinkAll(); | |||||
} | |||||
for (OutDataAnchorPtr &outAnchorPtr : node->GetAllOutDataAnchors()) { | |||||
outAnchorPtr->UnlinkAll(); | |||||
} | |||||
if (node->GetOutControlAnchor() != nullptr) { | |||||
node->GetOutControlAnchor()->UnlinkAll(); | |||||
} | |||||
GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node:%s failed", node->GetName().c_str()); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status UnusedOpRemovePass::CollectParentNode(const ComputeGraphPtr &graph, const NodePtr &node, | |||||
vector<NodePtr> &node_vec) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
GE_CHECK_NOTNULL(node); | |||||
node_vec.push_back(node); | |||||
std::queue<NodePtr> node_queue; | |||||
for (auto &src_node : node->GetInDataNodes()) { | |||||
if (src_node->GetOutDataNodesSize() == 1) { | |||||
node_queue.push(src_node); | |||||
} | |||||
} | |||||
while (!node_queue.empty()) { | |||||
NodePtr temp = node_queue.front(); | |||||
node_queue.pop(); | |||||
for (auto &src_node : temp->GetInDataNodes()) { | |||||
if (src_node->GetOutDataNodesSize() == 1) { | |||||
node_queue.push(src_node); | |||||
} | |||||
} | |||||
node_vec.push_back(temp); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
bool UnusedOpRemovePass::IsExceptions(const NodePtr &node) { | |||||
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); | |||||
auto op_def = node->GetOpDesc(); | |||||
GE_CHK_BOOL_EXEC(op_def != nullptr, return false, "opdesc is nullptr"); | |||||
// permute optimised in permute_pass.cpp | |||||
if (op_def->GetType() == PERMUTE) { | |||||
GE_IF_BOOL_EXEC( | |||||
(node->GetInDataNodes().size() != 0 && | |||||
(node->GetInDataNodes().at(0) != nullptr && node->GetInDataNodes().at(0)->GetOpDesc() != nullptr && | |||||
node->GetInDataNodes().at(0)->GetOpDesc()->GetType() == ATTENTIONDECODER)), | |||||
return false); | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
} // namespace ge |
@@ -1,41 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||||
#define GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||||
#include <string> | |||||
#include <vector> | |||||
#include "framework/common/ge_types.h" | |||||
#include "inc/graph_pass.h" | |||||
namespace ge { | |||||
class UnusedOpRemovePass : public GraphPass { | |||||
public: | |||||
explicit UnusedOpRemovePass(FrameworkType type) : fmktype_(type) {} | |||||
~UnusedOpRemovePass() {} | |||||
Status Run(ge::ComputeGraphPtr graph) override; | |||||
bool IsExceptions(const ge::NodePtr &node); | |||||
private: | |||||
Status CollectParentNode(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node, | |||||
std::vector<ge::NodePtr> &node_vec); | |||||
std::vector<std::string> v_remove_ops; | |||||
FrameworkType fmktype_; | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ |
@@ -1,119 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/variable_format_pass.h" | |||||
#include <map> | |||||
#include <set> | |||||
#include <string> | |||||
#include "framework/common/debug/ge_log.h" | |||||
namespace ge { | |||||
Status VariableFormatPass::Run(ge::ComputeGraphPtr graph) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
for (auto &node : graph->GetDirectNode()) { | |||||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | |||||
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); | |||||
ge::NodePtr use_node = nullptr; | |||||
if (GetApplyMomentumOpByVariableInput(node, use_node)) { | |||||
GE_CHK_STATUS_RET(UpdateVariableOutFormat(node, use_node), "update variable out format failed"); | |||||
GE_CHK_STATUS_RET(UpdateApplyMomentumInputFormat(use_node), "update apply momentum input format failed"); | |||||
} | |||||
} | |||||
return domi::SUCCESS; | |||||
} | |||||
bool VariableFormatPass::GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node) { | |||||
GE_IF_BOOL_EXEC(var_node == nullptr, return false); | |||||
std::map<std::string, std::set<int>> confirm_ops = {{"ApplyMomentum", {1}}}; | |||||
for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { | |||||
for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
GE_IF_BOOL_EXEC(ConfirmUseOpAndIndexByAnchor(in_anchor, confirm_ops, use_node), return true); | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
bool VariableFormatPass::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, | |||||
const map<string, std::set<int>> &confirm_ops, | |||||
ge::NodePtr &use_node) { | |||||
GE_IF_BOOL_EXEC(in_anchor == nullptr, return false); | |||||
ge::NodePtr dst_node = in_anchor->GetOwnerNode(); | |||||
ge::OpDescPtr dst_op_desc = dst_node->GetOpDesc(); | |||||
GE_IF_BOOL_EXEC(dst_op_desc == nullptr, return false); | |||||
const string &dst_type = dst_op_desc->GetType(); | |||||
int input_index = in_anchor->GetIdx(); | |||||
GELOGD("ConfirmUseOpAndIndex, var name %s, dst_type = %s, input index %d", dst_node->GetName().c_str(), | |||||
dst_type.c_str(), input_index); | |||||
GE_IF_BOOL_EXEC(confirm_ops.count(dst_type) > 0, | |||||
GE_IF_BOOL_EXEC(confirm_ops.at(dst_type).count(input_index) > 0, use_node = dst_node; return true);); | |||||
return false; | |||||
} | |||||
Status VariableFormatPass::UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node) { | |||||
GE_CHECK_NOTNULL(var_node); | |||||
GE_CHECK_NOTNULL(use_node); | |||||
ge::OpDescPtr op_desc_ptr = use_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc_ptr); | |||||
GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)); | |||||
GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)->GetPeerOutAnchor()); | |||||
NodePtr in_node = use_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
if (in_node != nullptr) { | |||||
string in_op_type = in_node->GetType(); | |||||
if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr) && | |||||
(in_node->GetOpDesc()->MutableOutputDesc(0) != nullptr)) { | |||||
ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); | |||||
ge::OpDescPtr cur_op_desc_ptr = var_node->GetOpDesc(); | |||||
if (cur_op_desc_ptr != nullptr) { | |||||
cur_op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); | |||||
cur_op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); | |||||
} | |||||
} | |||||
} | |||||
return domi::SUCCESS; | |||||
} | |||||
Status VariableFormatPass::UpdateApplyMomentumInputFormat(const ge::NodePtr &node) { | |||||
GE_CHECK_NOTNULL(node); | |||||
ge::OpDescPtr op_desc_ptr = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc_ptr); | |||||
GE_CHECK_NOTNULL(node->GetInDataAnchor(0)); | |||||
GE_CHECK_NOTNULL(node->GetInDataAnchor(0)->GetPeerOutAnchor()); | |||||
GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(0)); | |||||
GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(1)); | |||||
GE_CHECK_NOTNULL(op_desc_ptr->MutableOutputDesc(0)); | |||||
NodePtr in_node = node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
if (in_node != nullptr) { | |||||
string in_op_type = in_node->GetType(); | |||||
if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr)) { | |||||
ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); | |||||
op_desc_ptr->MutableInputDesc(0)->SetFormat(format); | |||||
op_desc_ptr->MutableInputDesc(0)->SetOriginFormat(format); | |||||
op_desc_ptr->MutableInputDesc(1)->SetFormat(format); | |||||
op_desc_ptr->MutableInputDesc(1)->SetOriginFormat(format); | |||||
op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); | |||||
op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); | |||||
} | |||||
} | |||||
return domi::SUCCESS; | |||||
} | |||||
} // namespace ge |
@@ -1,44 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||||
#define GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||||
#include <map> | |||||
#include <set> | |||||
#include <string> | |||||
#include "graph/types.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#include "inc/graph_pass.h" | |||||
namespace ge { | |||||
class VariableFormatPass : public GraphPass { | |||||
public: | |||||
Status Run(ge::ComputeGraphPtr graph) override; | |||||
private: | |||||
bool GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node); | |||||
bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, | |||||
const map<string, std::set<int> > &confirm_ops, ge::NodePtr &use_node); | |||||
Status UpdateApplyMomentumInputFormat(const ge::NodePtr &node); | |||||
Status UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node); | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ |
@@ -216,12 +216,10 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc" | ||||
@@ -263,7 +261,6 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" | "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/variable_format_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | ||||
@@ -495,8 +492,6 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc" | ||||
@@ -670,8 +665,7 @@ set(PASS_TEST_FILES | |||||
"graph/passes/permute_pass_unittest.cc" | "graph/passes/permute_pass_unittest.cc" | ||||
"graph/passes/print_op_pass_unittest.cc" | "graph/passes/print_op_pass_unittest.cc" | ||||
"graph/passes/shape_operate_op_remove_pass_unittest.cc" | "graph/passes/shape_operate_op_remove_pass_unittest.cc" | ||||
"graph/passes/unused_and_isolated_op_remove_pass_unittest.cc" | |||||
"graph/passes/variable_op_pass_unittest.cc" | |||||
"graph/passes/variable_op_pass_unittest.cc" | |||||
"graph/passes/base_pass_unittest.cc" | "graph/passes/base_pass_unittest.cc" | ||||
"graph/passes/addn_pass_unittest.cc" | "graph/passes/addn_pass_unittest.cc" | ||||
"graph/passes/save_pass_unittest.cc" | "graph/passes/save_pass_unittest.cc" | ||||
@@ -1,191 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/unused_op_remove_pass.h" | |||||
#include <gtest/gtest.h> | |||||
#include "graph/passes/isolated_op_remove_pass.h" | |||||
#include "pass_manager.h" | |||||
using namespace ge; | |||||
class UtestGraphPassesUnusedAndIsolatedOpRemovePass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
NodePtr AddNode(ComputeGraphPtr graph, const string &name, const string &type, int32_t in_anchors_num = 1, | |||||
int32_t out_anchors_num = 1) { | |||||
GeTensorDesc tensor_desc; | |||||
OpDescPtr op_desc = make_shared<OpDesc>(name, type); | |||||
for (int32_t i = 0; i < in_anchors_num; i++) { | |||||
op_desc->AddInputDesc(tensor_desc); | |||||
} | |||||
for (int32_t i = 0; i < out_anchors_num; i++) { | |||||
op_desc->AddOutputDesc(tensor_desc); | |||||
} | |||||
NodePtr node = graph->AddNode(op_desc); | |||||
return node; | |||||
} | |||||
}; | |||||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_reshape) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||||
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||||
NodePtr reshape_node = AddNode(graph, "reshape1", RESHAPE); | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), reshape_node->GetInDataAnchor(0)); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | |||||
EXPECT_EQ(SUCCESS, status); | |||||
NodePtr found_node = graph->FindNode("transpose1"); | |||||
EXPECT_EQ(transpose_node, found_node); | |||||
} | |||||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_squeeze) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||||
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||||
NodePtr squeeze_node = AddNode(graph, "squeeze1", SQUEEZE); | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), squeeze_node->GetInDataAnchor(0)); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | |||||
EXPECT_EQ(SUCCESS, status); | |||||
NodePtr found_node = graph->FindNode("transpose1"); | |||||
EXPECT_EQ(transpose_node, found_node); | |||||
} | |||||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||||
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||||
vector<int64_t> order_list = {0, 2, 3, 1}; | |||||
AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list); | |||||
AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT); | |||||
NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION); | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); | |||||
NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | |||||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | |||||
EXPECT_EQ(SUCCESS, status); | |||||
NodePtr found_node0 = graph->FindNode("transpose1"); | |||||
NodePtr found_node = graph->FindNode("conv1"); | |||||
EXPECT_EQ(conv_node, found_node); | |||||
} | |||||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv3) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||||
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||||
vector<int64_t> order_list = {0, 1, 3, 2}; | |||||
AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list); | |||||
AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT); | |||||
NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION); | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); | |||||
NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | |||||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | |||||
EXPECT_EQ(SUCCESS, status); | |||||
NodePtr found_node0 = graph->FindNode("transpose1"); | |||||
EXPECT_EQ(transpose_node, found_node0); | |||||
NodePtr found_node = graph->FindNode("conv1"); | |||||
EXPECT_EQ(conv_node, found_node); | |||||
} | |||||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, cast_and_cast) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||||
NodePtr conv3_node = AddNode(graph, "cast3", CAST); | |||||
NodePtr transpose_node = AddNode(graph, "cast1", CAST); | |||||
NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST); | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | |||||
EXPECT_EQ(SUCCESS, status); | |||||
} | |||||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, remove_parent_node) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
vector<NodePtr> node_vec; | |||||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||||
NodePtr conv3_node = AddNode(graph, "cast3", CAST); | |||||
NodePtr transpose_node = AddNode(graph, "cast1", CAST); | |||||
NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST); | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | |||||
EXPECT_EQ(SUCCESS, status); | |||||
} |