diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 6ff9f5d9..b0aa082a 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -270,7 +270,6 @@ set(TRAIN_SRC_LIST "graph/passes/identity_pass.cc" "graph/passes/ref_identity_delete_op_pass.cc" "graph/passes/infershape_pass.cc" - "graph/passes/isolated_op_remove_pass.cc" "graph/passes/iterator_op_pass.cc" "graph/passes/link_gen_mask_nodes_pass.cc" "graph/passes/merge_pass.cc" @@ -317,13 +316,11 @@ set(TRAIN_SRC_LIST "graph/passes/transop_without_reshape_fusion_pass.cc" "graph/passes/transpose_transdata_pass.cc" "graph/passes/unused_const_pass.cc" - "graph/passes/unused_op_remove_pass.cc" "graph/passes/var_is_initialized_op_pass.cc" "graph/passes/parallel_concat_start_op_pass.cc" "graph/passes/cond_pass.cc" "graph/passes/cond_remove_pass.cc" "graph/passes/for_pass.cc" - "graph/passes/variable_format_pass.cc" "graph/passes/variable_op_pass.cc" "graph/passes/variable_prepare_op_pass.cc" "graph/passes/variable_ref_delete_op_pass.cc" @@ -522,12 +519,10 @@ set(INFER_SRC_LIST "graph/passes/dimension_adjust_pass.cc" "graph/passes/get_original_format_pass.cc" "graph/passes/shape_operate_op_remove_pass.cc" - "graph/passes/unused_op_remove_pass.cc" "graph/passes/assert_pass.cc" "graph/passes/dropout_pass.cc" "graph/passes/infershape_pass.cc" "graph/passes/unused_const_pass.cc" - "graph/passes/isolated_op_remove_pass.cc" "graph/passes/permute_pass.cc" "graph/passes/ctrl_edge_transfer_pass.cc" "graph/passes/end_of_sequence_add_control_pass.cc" @@ -610,7 +605,6 @@ set(INFER_SRC_LIST "graph/passes/switch_logic_remove_pass.cc" "graph/passes/switch_data_edges_bypass.cc" "graph/passes/merge_pass.cc" - "graph/passes/variable_format_pass.cc" "graph/passes/variable_op_pass.cc" "graph/passes/cast_remove_pass.cc" "graph/passes/transpose_transdata_pass.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index ae1288f5..a56eaadf 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -122,12 +122,10 @@ OMG_HOST_SRC_FILES := \ graph/passes/dimension_adjust_pass.cc \ graph/passes/get_original_format_pass.cc \ graph/passes/shape_operate_op_remove_pass.cc \ - graph/passes/unused_op_remove_pass.cc \ graph/passes/assert_pass.cc \ graph/passes/dropout_pass.cc \ graph/passes/infershape_pass.cc \ graph/passes/unused_const_pass.cc \ - graph/passes/isolated_op_remove_pass.cc \ graph/passes/permute_pass.cc \ graph/passes/ctrl_edge_transfer_pass.cc \ graph/passes/end_of_sequence_add_control_pass.cc \ @@ -209,7 +207,6 @@ OMG_HOST_SRC_FILES := \ graph/passes/switch_logic_remove_pass.cc \ graph/passes/switch_data_edges_bypass.cc \ graph/passes/merge_pass.cc \ - graph/passes/variable_format_pass.cc \ graph/passes/variable_op_pass.cc \ graph/passes/cast_remove_pass.cc \ graph/passes/transpose_transdata_pass.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 2aa19e7a..8ca8572c 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -187,7 +187,6 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/identity_pass.cc \ graph/passes/ref_identity_delete_op_pass.cc \ graph/passes/infershape_pass.cc \ - graph/passes/isolated_op_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ graph/passes/link_gen_mask_nodes_pass.cc \ graph/passes/merge_pass.cc \ @@ -233,13 +232,11 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/transop_without_reshape_fusion_pass.cc \ graph/passes/transpose_transdata_pass.cc \ graph/passes/unused_const_pass.cc \ - graph/passes/unused_op_remove_pass.cc \ graph/passes/var_is_initialized_op_pass.cc \ graph/passes/parallel_concat_start_op_pass.cc \ graph/passes/cond_pass.cc \ graph/passes/cond_remove_pass.cc \ graph/passes/for_pass.cc \ - graph/passes/variable_format_pass.cc \ graph/passes/variable_op_pass.cc \ graph/passes/variable_prepare_op_pass.cc \ graph/passes/variable_ref_delete_op_pass.cc \ diff --git a/ge/graph/passes/isolated_op_remove_pass.cc b/ge/graph/passes/isolated_op_remove_pass.cc deleted file mode 100644 index 5c9093e9..00000000 --- a/ge/graph/passes/isolated_op_remove_pass.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/isolated_op_remove_pass.h" - -#include "common/debug/log.h" -#include "common/types.h" -#include "common/util.h" - -namespace ge { -Status IsolatedOpRemovePass::Run(ge::ComputeGraphPtr graph) { - GE_CHECK_NOTNULL(graph); - for (NodePtr &node_ptr : graph->GetDirectNode()) { - GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, continue); - if (node_ptr->GetInDataNodes().size() == 0 && node_ptr->GetOutAllNodes().size() == 0 && - !(node_ptr->GetOpDesc()->HasAttr(TO_BE_OUTPUT))) { - GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node_ptr), "remove graph node [%s] fail", - node_ptr->GetOpDesc()->GetName().c_str()); - } - } - - return SUCCESS; -} -} // namespace ge diff --git a/ge/graph/passes/isolated_op_remove_pass.h b/ge/graph/passes/isolated_op_remove_pass.h deleted file mode 100755 index 3b7fe7d1..00000000 --- a/ge/graph/passes/isolated_op_remove_pass.h +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ -#define GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ - -#include "inc/graph_pass.h" - -namespace ge { -class IsolatedOpRemovePass : public GraphPass { - public: - Status Run(ge::ComputeGraphPtr graph); -}; -} // namespace ge -#endif // GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ diff --git a/ge/graph/passes/remove_nodes_pass.cc b/ge/graph/passes/remove_nodes_pass.cc deleted file mode 100644 index c238f003..00000000 --- a/ge/graph/passes/remove_nodes_pass.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "remove_nodes_pass.h" -#include "debug/ge_log.h" -#include "inc/framework/common/util.h" -#include "inc/graph/utils/node_utils.h" - -namespace ge { -Status RemoveNodesPass::Run(NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto node_type = NodeUtils::GetNodeType(*node); - auto type_iter = remove_node_types_to_arg_.find(node_type); - if (type_iter != remove_node_types_to_arg_.end()) { - GELOGI("Remove node %s by type %s", node->GetName().c_str(), node_type.c_str()); - return IsolateAndDeleteNode(node, type_iter->second); - } - for (const auto &attr_name_to_arg : remove_node_attr_names_to_arg_) { - if (AttrUtils::HasAttr(node->GetOpDesc(), attr_name_to_arg.first)) { - GELOGI("Remove node %s by attr name %s", node->GetName().c_str(), attr_name_to_arg.first.c_str()); - return IsolateAndDeleteNode(node, attr_name_to_arg.second); - } - } - - return SUCCESS; -} -RemoveNodesPass &RemoveNodesPass::AddNodeType(const string &node_type, std::initializer_list arg) { - remove_node_types_to_arg_[node_type] = std::move(arg); - return *this; -} -RemoveNodesPass &RemoveNodesPass::AddAttrName(const string &attr_name, std::initializer_list arg) { - remove_node_attr_names_to_arg_[attr_name] = std::move(arg); - return *this; -} -} // namespace ge \ No newline at end of file diff --git a/ge/graph/passes/remove_nodes_pass.h b/ge/graph/passes/remove_nodes_pass.h deleted file mode 100644 index 1d4fced9..00000000 --- a/ge/graph/passes/remove_nodes_pass.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef GE_REMOVE_NODES_PASS_H_ -#define GE_REMOVE_NODES_PASS_H_ -#include "graph/passes/base_pass.h" - -namespace ge { -class RemoveNodesPass : public BaseNodePass { - public: - Status Run(NodePtr &node) override; - RemoveNodesPass &AddNodeType(const std::string &node_type, std::initializer_list arg = {0}); - RemoveNodesPass &AddAttrName(const std::string &attr_name, std::initializer_list arg = {0}); - - private: - std::map> remove_node_types_to_arg_; - std::map> remove_node_attr_names_to_arg_; -}; -} // namespace ge -#endif //GE_REMOVE_NODES_PASS_H_ diff --git a/ge/graph/passes/unused_op_remove_pass.cc b/ge/graph/passes/unused_op_remove_pass.cc deleted file mode 100644 index 41f7c828..00000000 --- a/ge/graph/passes/unused_op_remove_pass.cc +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/unused_op_remove_pass.h" -#include -#include -#include -#include -#include "common/debug/log.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "inc/pass_manager.h" -#include "graph/passes/isolated_op_remove_pass.h" - -using domi::SUCCESS; - -namespace ge { -const std::set kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT}; -const std::set kOtherRemoveOpSet = {DROPOUT}; - -Status UnusedOpRemovePass::Run(ComputeGraphPtr graph) { - GE_CHECK_NOTNULL(graph); - std::set remove_op_set; - vector nodes_to_be_deleted; - if (fmktype_ == TENSORFLOW) { - remove_op_set = kRemoveOpSet; - } else { - remove_op_set = kOtherRemoveOpSet; - } - - for (auto &node : graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - std::string op_type_str = node->GetOpDesc()->GetType(); - if (remove_op_set.count(op_type_str)) { - if (IsExceptions(node)) { - continue; - } - for (auto &out_anchor : node->GetAllOutDataAnchors()) { - for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - NodePtr dst_node = in_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(dst_node->GetOpDesc()); - int dst_index = in_anchor->GetIdx(); - std::vector list_bool; - GE_CHECK_NOTNULL(dst_node->GetOpDesc()); - list_bool = dst_node->GetOpDesc()->GetIsInputConst(); - GE_IF_BOOL_EXEC(list_bool.size() == 0, continue); - list_bool.erase(list_bool.begin() + dst_index); - dst_node->GetOpDesc()->SetIsInputConst(list_bool); - } - } - if (op_type_str == ASSERT) { - GE_CHK_STATUS_RET(CollectParentNode(graph, node, nodes_to_be_deleted), "remove node failed"); - } else { - GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node failed"); - } - } - } - for (auto &node : nodes_to_be_deleted) { - for (InDataAnchorPtr &inAnchor : node->GetAllInDataAnchors()) { - inAnchor->UnlinkAll(); - } - for (OutDataAnchorPtr &outAnchorPtr : node->GetAllOutDataAnchors()) { - outAnchorPtr->UnlinkAll(); - } - if (node->GetOutControlAnchor() != nullptr) { - node->GetOutControlAnchor()->UnlinkAll(); - } - GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node:%s failed", node->GetName().c_str()); - } - - return SUCCESS; -} - -Status UnusedOpRemovePass::CollectParentNode(const ComputeGraphPtr &graph, const NodePtr &node, - vector &node_vec) { - GE_CHECK_NOTNULL(graph); - GE_CHECK_NOTNULL(node); - node_vec.push_back(node); - std::queue node_queue; - - for (auto &src_node : node->GetInDataNodes()) { - if (src_node->GetOutDataNodesSize() == 1) { - node_queue.push(src_node); - } - } - - while (!node_queue.empty()) { - NodePtr temp = node_queue.front(); - node_queue.pop(); - - for (auto &src_node : temp->GetInDataNodes()) { - if (src_node->GetOutDataNodesSize() == 1) { - node_queue.push(src_node); - } - } - node_vec.push_back(temp); - } - - return SUCCESS; -} - -bool UnusedOpRemovePass::IsExceptions(const NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); - auto op_def = node->GetOpDesc(); - GE_CHK_BOOL_EXEC(op_def != nullptr, return false, "opdesc is nullptr"); - // permute optimised in permute_pass.cpp - if (op_def->GetType() == PERMUTE) { - GE_IF_BOOL_EXEC( - (node->GetInDataNodes().size() != 0 && - (node->GetInDataNodes().at(0) != nullptr && node->GetInDataNodes().at(0)->GetOpDesc() != nullptr && - node->GetInDataNodes().at(0)->GetOpDesc()->GetType() == ATTENTIONDECODER)), - return false); - return true; - } - return false; -} -} // namespace ge diff --git a/ge/graph/passes/unused_op_remove_pass.h b/ge/graph/passes/unused_op_remove_pass.h deleted file mode 100755 index b9429cfd..00000000 --- a/ge/graph/passes/unused_op_remove_pass.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ -#define GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ - -#include -#include -#include "framework/common/ge_types.h" -#include "inc/graph_pass.h" - -namespace ge { -class UnusedOpRemovePass : public GraphPass { - public: - explicit UnusedOpRemovePass(FrameworkType type) : fmktype_(type) {} - ~UnusedOpRemovePass() {} - Status Run(ge::ComputeGraphPtr graph) override; - bool IsExceptions(const ge::NodePtr &node); - - private: - Status CollectParentNode(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node, - std::vector &node_vec); - std::vector v_remove_ops; - FrameworkType fmktype_; -}; -} // namespace ge - -#endif // GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ diff --git a/ge/graph/passes/variable_format_pass.cc b/ge/graph/passes/variable_format_pass.cc deleted file mode 100644 index bd5300a5..00000000 --- a/ge/graph/passes/variable_format_pass.cc +++ /dev/null @@ -1,119 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/variable_format_pass.h" -#include -#include -#include -#include "framework/common/debug/ge_log.h" - -namespace ge { -Status VariableFormatPass::Run(ge::ComputeGraphPtr graph) { - GE_CHECK_NOTNULL(graph); - - for (auto &node : graph->GetDirectNode()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); - - ge::NodePtr use_node = nullptr; - if (GetApplyMomentumOpByVariableInput(node, use_node)) { - GE_CHK_STATUS_RET(UpdateVariableOutFormat(node, use_node), "update variable out format failed"); - GE_CHK_STATUS_RET(UpdateApplyMomentumInputFormat(use_node), "update apply momentum input format failed"); - } - } - - return domi::SUCCESS; -} - -bool VariableFormatPass::GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node) { - GE_IF_BOOL_EXEC(var_node == nullptr, return false); - - std::map> confirm_ops = {{"ApplyMomentum", {1}}}; - for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { - for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(ConfirmUseOpAndIndexByAnchor(in_anchor, confirm_ops, use_node), return true); - } - } - - return false; -} - -bool VariableFormatPass::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, - const map> &confirm_ops, - ge::NodePtr &use_node) { - GE_IF_BOOL_EXEC(in_anchor == nullptr, return false); - ge::NodePtr dst_node = in_anchor->GetOwnerNode(); - ge::OpDescPtr dst_op_desc = dst_node->GetOpDesc(); - GE_IF_BOOL_EXEC(dst_op_desc == nullptr, return false); - const string &dst_type = dst_op_desc->GetType(); - int input_index = in_anchor->GetIdx(); - - GELOGD("ConfirmUseOpAndIndex, var name %s, dst_type = %s, input index %d", dst_node->GetName().c_str(), - dst_type.c_str(), input_index); - - GE_IF_BOOL_EXEC(confirm_ops.count(dst_type) > 0, - GE_IF_BOOL_EXEC(confirm_ops.at(dst_type).count(input_index) > 0, use_node = dst_node; return true);); - return false; -} - -Status VariableFormatPass::UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node) { - GE_CHECK_NOTNULL(var_node); - GE_CHECK_NOTNULL(use_node); - ge::OpDescPtr op_desc_ptr = use_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc_ptr); - GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)); - GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)->GetPeerOutAnchor()); - NodePtr in_node = use_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - if (in_node != nullptr) { - string in_op_type = in_node->GetType(); - if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr) && - (in_node->GetOpDesc()->MutableOutputDesc(0) != nullptr)) { - ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); - ge::OpDescPtr cur_op_desc_ptr = var_node->GetOpDesc(); - if (cur_op_desc_ptr != nullptr) { - cur_op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); - cur_op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); - } - } - } - return domi::SUCCESS; -} - -Status VariableFormatPass::UpdateApplyMomentumInputFormat(const ge::NodePtr &node) { - GE_CHECK_NOTNULL(node); - ge::OpDescPtr op_desc_ptr = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc_ptr); - GE_CHECK_NOTNULL(node->GetInDataAnchor(0)); - GE_CHECK_NOTNULL(node->GetInDataAnchor(0)->GetPeerOutAnchor()); - GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(0)); - GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(1)); - GE_CHECK_NOTNULL(op_desc_ptr->MutableOutputDesc(0)); - NodePtr in_node = node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - if (in_node != nullptr) { - string in_op_type = in_node->GetType(); - if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr)) { - ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); - op_desc_ptr->MutableInputDesc(0)->SetFormat(format); - op_desc_ptr->MutableInputDesc(0)->SetOriginFormat(format); - op_desc_ptr->MutableInputDesc(1)->SetFormat(format); - op_desc_ptr->MutableInputDesc(1)->SetOriginFormat(format); - op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); - op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); - } - } - return domi::SUCCESS; -} -} // namespace ge diff --git a/ge/graph/passes/variable_format_pass.h b/ge/graph/passes/variable_format_pass.h deleted file mode 100755 index e2c32903..00000000 --- a/ge/graph/passes/variable_format_pass.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ -#define GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ - -#include -#include -#include -#include "graph/types.h" -#include "graph/utils/op_desc_utils.h" -#include "inc/graph_pass.h" - -namespace ge { -class VariableFormatPass : public GraphPass { - public: - Status Run(ge::ComputeGraphPtr graph) override; - - private: - bool GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node); - - bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, - const map > &confirm_ops, ge::NodePtr &use_node); - - Status UpdateApplyMomentumInputFormat(const ge::NodePtr &node); - - Status UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node); -}; -} // namespace ge - -#endif // GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 7cdec968..895c33df 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -216,12 +216,10 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc" @@ -263,7 +261,6 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/variable_format_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" @@ -495,8 +492,6 @@ set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc" @@ -670,8 +665,7 @@ set(PASS_TEST_FILES "graph/passes/permute_pass_unittest.cc" "graph/passes/print_op_pass_unittest.cc" "graph/passes/shape_operate_op_remove_pass_unittest.cc" - "graph/passes/unused_and_isolated_op_remove_pass_unittest.cc" - "graph/passes/variable_op_pass_unittest.cc" + "graph/passes/variable_op_pass_unittest.cc" "graph/passes/base_pass_unittest.cc" "graph/passes/addn_pass_unittest.cc" "graph/passes/save_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc b/tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc deleted file mode 100644 index 21b5d7e3..00000000 --- a/tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/unused_op_remove_pass.h" - -#include -#include "graph/passes/isolated_op_remove_pass.h" -#include "pass_manager.h" - -using namespace ge; - -class UtestGraphPassesUnusedAndIsolatedOpRemovePass : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} - - NodePtr AddNode(ComputeGraphPtr graph, const string &name, const string &type, int32_t in_anchors_num = 1, - int32_t out_anchors_num = 1) { - GeTensorDesc tensor_desc; - OpDescPtr op_desc = make_shared(name, type); - for (int32_t i = 0; i < in_anchors_num; i++) { - op_desc->AddInputDesc(tensor_desc); - } - for (int32_t i = 0; i < out_anchors_num; i++) { - op_desc->AddOutputDesc(tensor_desc); - } - - NodePtr node = graph->AddNode(op_desc); - return node; - } -}; - -TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_reshape) { - ComputeGraphPtr graph = std::make_shared("test"); - - NodePtr data_node = AddNode(graph, "DATA", DATA); - NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); - NodePtr reshape_node = AddNode(graph, "reshape1", RESHAPE); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), reshape_node->GetInDataAnchor(0)); - - ge::UnusedOpRemovePass unused_pass(TENSORFLOW); - ge::IsolatedOpRemovePass isolate_pass; - std::vector> passes; - passes.emplace_back("", &isolate_pass); - passes.emplace_back("", &unused_pass); - Status status = PassManager::Run(graph, passes); - EXPECT_EQ(SUCCESS, status); - NodePtr found_node = graph->FindNode("transpose1"); - EXPECT_EQ(transpose_node, found_node); -} - -TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_squeeze) { - ComputeGraphPtr graph = std::make_shared("test"); - - NodePtr data_node = AddNode(graph, "DATA", DATA); - NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); - NodePtr squeeze_node = AddNode(graph, "squeeze1", SQUEEZE); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), squeeze_node->GetInDataAnchor(0)); - - ge::UnusedOpRemovePass unused_pass(TENSORFLOW); - ge::IsolatedOpRemovePass isolate_pass; - std::vector> passes; - passes.emplace_back("", &isolate_pass); - passes.emplace_back("", &unused_pass); - Status status = PassManager::Run(graph, passes); - EXPECT_EQ(SUCCESS, status); - NodePtr found_node = graph->FindNode("transpose1"); - EXPECT_EQ(transpose_node, found_node); -} - -TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv) { - ComputeGraphPtr graph = std::make_shared("test"); - - NodePtr data_node = AddNode(graph, "DATA", DATA); - - NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); - vector order_list = {0, 2, 3, 1}; - AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list); - AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT); - - NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); - - NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); - - ge::UnusedOpRemovePass unused_pass(TENSORFLOW); - ge::IsolatedOpRemovePass isolate_pass; - std::vector> passes; - passes.emplace_back("", &isolate_pass); - passes.emplace_back("", &unused_pass); - Status status = PassManager::Run(graph, passes); - EXPECT_EQ(SUCCESS, status); - NodePtr found_node0 = graph->FindNode("transpose1"); - NodePtr found_node = graph->FindNode("conv1"); - EXPECT_EQ(conv_node, found_node); -} - -TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv3) { - ComputeGraphPtr graph = std::make_shared("test"); - - NodePtr data_node = AddNode(graph, "DATA", DATA); - - NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); - vector order_list = {0, 1, 3, 2}; - AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list); - AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT); - - NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); - - NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); - - ge::UnusedOpRemovePass unused_pass(TENSORFLOW); - ge::IsolatedOpRemovePass isolate_pass; - std::vector> passes; - passes.emplace_back("", &isolate_pass); - passes.emplace_back("", &unused_pass); - Status status = PassManager::Run(graph, passes); - EXPECT_EQ(SUCCESS, status); - NodePtr found_node0 = graph->FindNode("transpose1"); - EXPECT_EQ(transpose_node, found_node0); - NodePtr found_node = graph->FindNode("conv1"); - EXPECT_EQ(conv_node, found_node); -} - -TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, cast_and_cast) { - ComputeGraphPtr graph = std::make_shared("test"); - - NodePtr data_node = AddNode(graph, "DATA", DATA); - NodePtr conv3_node = AddNode(graph, "cast3", CAST); - NodePtr transpose_node = AddNode(graph, "cast1", CAST); - NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); - - ge::UnusedOpRemovePass unused_pass(TENSORFLOW); - ge::IsolatedOpRemovePass isolate_pass; - std::vector> passes; - passes.emplace_back("", &isolate_pass); - passes.emplace_back("", &unused_pass); - Status status = PassManager::Run(graph, passes); - EXPECT_EQ(SUCCESS, status); -} - -TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, remove_parent_node) { - ComputeGraphPtr graph = std::make_shared("test"); - vector node_vec; - - NodePtr data_node = AddNode(graph, "DATA", DATA); - NodePtr conv3_node = AddNode(graph, "cast3", CAST); - NodePtr transpose_node = AddNode(graph, "cast1", CAST); - NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); - - ge::UnusedOpRemovePass unused_pass(TENSORFLOW); - ge::IsolatedOpRemovePass isolate_pass; - std::vector> passes; - passes.emplace_back("", &isolate_pass); - passes.emplace_back("", &unused_pass); - Status status = PassManager::Run(graph, passes); - EXPECT_EQ(SUCCESS, status); -}