From: @hugo1 Reviewed-by: @sheng-nan,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
| @@ -270,7 +270,6 @@ set(TRAIN_SRC_LIST | |||
| "graph/passes/identity_pass.cc" | |||
| "graph/passes/ref_identity_delete_op_pass.cc" | |||
| "graph/passes/infershape_pass.cc" | |||
| "graph/passes/isolated_op_remove_pass.cc" | |||
| "graph/passes/iterator_op_pass.cc" | |||
| "graph/passes/link_gen_mask_nodes_pass.cc" | |||
| "graph/passes/merge_pass.cc" | |||
| @@ -317,13 +316,11 @@ set(TRAIN_SRC_LIST | |||
| "graph/passes/transop_without_reshape_fusion_pass.cc" | |||
| "graph/passes/transpose_transdata_pass.cc" | |||
| "graph/passes/unused_const_pass.cc" | |||
| "graph/passes/unused_op_remove_pass.cc" | |||
| "graph/passes/var_is_initialized_op_pass.cc" | |||
| "graph/passes/parallel_concat_start_op_pass.cc" | |||
| "graph/passes/cond_pass.cc" | |||
| "graph/passes/cond_remove_pass.cc" | |||
| "graph/passes/for_pass.cc" | |||
| "graph/passes/variable_format_pass.cc" | |||
| "graph/passes/variable_op_pass.cc" | |||
| "graph/passes/variable_prepare_op_pass.cc" | |||
| "graph/passes/variable_ref_delete_op_pass.cc" | |||
| @@ -522,12 +519,10 @@ set(INFER_SRC_LIST | |||
| "graph/passes/dimension_adjust_pass.cc" | |||
| "graph/passes/get_original_format_pass.cc" | |||
| "graph/passes/shape_operate_op_remove_pass.cc" | |||
| "graph/passes/unused_op_remove_pass.cc" | |||
| "graph/passes/assert_pass.cc" | |||
| "graph/passes/dropout_pass.cc" | |||
| "graph/passes/infershape_pass.cc" | |||
| "graph/passes/unused_const_pass.cc" | |||
| "graph/passes/isolated_op_remove_pass.cc" | |||
| "graph/passes/permute_pass.cc" | |||
| "graph/passes/ctrl_edge_transfer_pass.cc" | |||
| "graph/passes/end_of_sequence_add_control_pass.cc" | |||
| @@ -610,7 +605,6 @@ set(INFER_SRC_LIST | |||
| "graph/passes/switch_logic_remove_pass.cc" | |||
| "graph/passes/switch_data_edges_bypass.cc" | |||
| "graph/passes/merge_pass.cc" | |||
| "graph/passes/variable_format_pass.cc" | |||
| "graph/passes/variable_op_pass.cc" | |||
| "graph/passes/cast_remove_pass.cc" | |||
| "graph/passes/transpose_transdata_pass.cc" | |||
| @@ -122,12 +122,10 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/dimension_adjust_pass.cc \ | |||
| graph/passes/get_original_format_pass.cc \ | |||
| graph/passes/shape_operate_op_remove_pass.cc \ | |||
| graph/passes/unused_op_remove_pass.cc \ | |||
| graph/passes/assert_pass.cc \ | |||
| graph/passes/dropout_pass.cc \ | |||
| graph/passes/infershape_pass.cc \ | |||
| graph/passes/unused_const_pass.cc \ | |||
| graph/passes/isolated_op_remove_pass.cc \ | |||
| graph/passes/permute_pass.cc \ | |||
| graph/passes/ctrl_edge_transfer_pass.cc \ | |||
| graph/passes/end_of_sequence_add_control_pass.cc \ | |||
| @@ -209,7 +207,6 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/switch_logic_remove_pass.cc \ | |||
| graph/passes/switch_data_edges_bypass.cc \ | |||
| graph/passes/merge_pass.cc \ | |||
| graph/passes/variable_format_pass.cc \ | |||
| graph/passes/variable_op_pass.cc \ | |||
| graph/passes/cast_remove_pass.cc \ | |||
| graph/passes/transpose_transdata_pass.cc \ | |||
| @@ -187,7 +187,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/passes/identity_pass.cc \ | |||
| graph/passes/ref_identity_delete_op_pass.cc \ | |||
| graph/passes/infershape_pass.cc \ | |||
| graph/passes/isolated_op_remove_pass.cc \ | |||
| graph/passes/iterator_op_pass.cc \ | |||
| graph/passes/link_gen_mask_nodes_pass.cc \ | |||
| graph/passes/merge_pass.cc \ | |||
| @@ -233,13 +232,11 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/passes/transop_without_reshape_fusion_pass.cc \ | |||
| graph/passes/transpose_transdata_pass.cc \ | |||
| graph/passes/unused_const_pass.cc \ | |||
| graph/passes/unused_op_remove_pass.cc \ | |||
| graph/passes/var_is_initialized_op_pass.cc \ | |||
| graph/passes/parallel_concat_start_op_pass.cc \ | |||
| graph/passes/cond_pass.cc \ | |||
| graph/passes/cond_remove_pass.cc \ | |||
| graph/passes/for_pass.cc \ | |||
| graph/passes/variable_format_pass.cc \ | |||
| graph/passes/variable_op_pass.cc \ | |||
| graph/passes/variable_prepare_op_pass.cc \ | |||
| graph/passes/variable_ref_delete_op_pass.cc \ | |||
| @@ -1,37 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/isolated_op_remove_pass.h" | |||
| #include "common/debug/log.h" | |||
| #include "common/types.h" | |||
| #include "common/util.h" | |||
| namespace ge { | |||
| Status IsolatedOpRemovePass::Run(ge::ComputeGraphPtr graph) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| for (NodePtr &node_ptr : graph->GetDirectNode()) { | |||
| GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, continue); | |||
| if (node_ptr->GetInDataNodes().size() == 0 && node_ptr->GetOutAllNodes().size() == 0 && | |||
| !(node_ptr->GetOpDesc()->HasAttr(TO_BE_OUTPUT))) { | |||
| GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node_ptr), "remove graph node [%s] fail", | |||
| node_ptr->GetOpDesc()->GetName().c_str()); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -1,28 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||
| #define GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||
| #include "inc/graph_pass.h" | |||
| namespace ge { | |||
| class IsolatedOpRemovePass : public GraphPass { | |||
| public: | |||
| Status Run(ge::ComputeGraphPtr graph); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||
| @@ -1,47 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "remove_nodes_pass.h" | |||
| #include "debug/ge_log.h" | |||
| #include "inc/framework/common/util.h" | |||
| #include "inc/graph/utils/node_utils.h" | |||
| namespace ge { | |||
| Status RemoveNodesPass::Run(NodePtr &node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| auto node_type = NodeUtils::GetNodeType(*node); | |||
| auto type_iter = remove_node_types_to_arg_.find(node_type); | |||
| if (type_iter != remove_node_types_to_arg_.end()) { | |||
| GELOGI("Remove node %s by type %s", node->GetName().c_str(), node_type.c_str()); | |||
| return IsolateAndDeleteNode(node, type_iter->second); | |||
| } | |||
| for (const auto &attr_name_to_arg : remove_node_attr_names_to_arg_) { | |||
| if (AttrUtils::HasAttr(node->GetOpDesc(), attr_name_to_arg.first)) { | |||
| GELOGI("Remove node %s by attr name %s", node->GetName().c_str(), attr_name_to_arg.first.c_str()); | |||
| return IsolateAndDeleteNode(node, attr_name_to_arg.second); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| RemoveNodesPass &RemoveNodesPass::AddNodeType(const string &node_type, std::initializer_list<int> arg) { | |||
| remove_node_types_to_arg_[node_type] = std::move(arg); | |||
| return *this; | |||
| } | |||
| RemoveNodesPass &RemoveNodesPass::AddAttrName(const string &attr_name, std::initializer_list<int> arg) { | |||
| remove_node_attr_names_to_arg_[attr_name] = std::move(arg); | |||
| return *this; | |||
| } | |||
| } // namespace ge | |||
| @@ -1,32 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_REMOVE_NODES_PASS_H_ | |||
| #define GE_REMOVE_NODES_PASS_H_ | |||
| #include "graph/passes/base_pass.h" | |||
| namespace ge { | |||
| class RemoveNodesPass : public BaseNodePass { | |||
| public: | |||
| Status Run(NodePtr &node) override; | |||
| RemoveNodesPass &AddNodeType(const std::string &node_type, std::initializer_list<int> arg = {0}); | |||
| RemoveNodesPass &AddAttrName(const std::string &attr_name, std::initializer_list<int> arg = {0}); | |||
| private: | |||
| std::map<std::string, std::initializer_list<int>> remove_node_types_to_arg_; | |||
| std::map<std::string, std::initializer_list<int>> remove_node_attr_names_to_arg_; | |||
| }; | |||
| } // namespace ge | |||
| #endif //GE_REMOVE_NODES_PASS_H_ | |||
| @@ -1,134 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/unused_op_remove_pass.h" | |||
| #include <queue> | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "common/debug/log.h" | |||
| #include "common/op/ge_op_utils.h" | |||
| #include "common/types.h" | |||
| #include "common/util.h" | |||
| #include "graph/utils/attr_utils.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| #include "inc/pass_manager.h" | |||
| #include "graph/passes/isolated_op_remove_pass.h" | |||
| using domi::SUCCESS; | |||
| namespace ge { | |||
| const std::set<std::string> kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT}; | |||
| const std::set<std::string> kOtherRemoveOpSet = {DROPOUT}; | |||
| Status UnusedOpRemovePass::Run(ComputeGraphPtr graph) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| std::set<std::string> remove_op_set; | |||
| vector<NodePtr> nodes_to_be_deleted; | |||
| if (fmktype_ == TENSORFLOW) { | |||
| remove_op_set = kRemoveOpSet; | |||
| } else { | |||
| remove_op_set = kOtherRemoveOpSet; | |||
| } | |||
| for (auto &node : graph->GetDirectNode()) { | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| std::string op_type_str = node->GetOpDesc()->GetType(); | |||
| if (remove_op_set.count(op_type_str)) { | |||
| if (IsExceptions(node)) { | |||
| continue; | |||
| } | |||
| for (auto &out_anchor : node->GetAllOutDataAnchors()) { | |||
| for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
| NodePtr dst_node = in_anchor->GetOwnerNode(); | |||
| GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||
| int dst_index = in_anchor->GetIdx(); | |||
| std::vector<bool> list_bool; | |||
| GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||
| list_bool = dst_node->GetOpDesc()->GetIsInputConst(); | |||
| GE_IF_BOOL_EXEC(list_bool.size() == 0, continue); | |||
| list_bool.erase(list_bool.begin() + dst_index); | |||
| dst_node->GetOpDesc()->SetIsInputConst(list_bool); | |||
| } | |||
| } | |||
| if (op_type_str == ASSERT) { | |||
| GE_CHK_STATUS_RET(CollectParentNode(graph, node, nodes_to_be_deleted), "remove node failed"); | |||
| } else { | |||
| GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node failed"); | |||
| } | |||
| } | |||
| } | |||
| for (auto &node : nodes_to_be_deleted) { | |||
| for (InDataAnchorPtr &inAnchor : node->GetAllInDataAnchors()) { | |||
| inAnchor->UnlinkAll(); | |||
| } | |||
| for (OutDataAnchorPtr &outAnchorPtr : node->GetAllOutDataAnchors()) { | |||
| outAnchorPtr->UnlinkAll(); | |||
| } | |||
| if (node->GetOutControlAnchor() != nullptr) { | |||
| node->GetOutControlAnchor()->UnlinkAll(); | |||
| } | |||
| GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node:%s failed", node->GetName().c_str()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status UnusedOpRemovePass::CollectParentNode(const ComputeGraphPtr &graph, const NodePtr &node, | |||
| vector<NodePtr> &node_vec) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| GE_CHECK_NOTNULL(node); | |||
| node_vec.push_back(node); | |||
| std::queue<NodePtr> node_queue; | |||
| for (auto &src_node : node->GetInDataNodes()) { | |||
| if (src_node->GetOutDataNodesSize() == 1) { | |||
| node_queue.push(src_node); | |||
| } | |||
| } | |||
| while (!node_queue.empty()) { | |||
| NodePtr temp = node_queue.front(); | |||
| node_queue.pop(); | |||
| for (auto &src_node : temp->GetInDataNodes()) { | |||
| if (src_node->GetOutDataNodesSize() == 1) { | |||
| node_queue.push(src_node); | |||
| } | |||
| } | |||
| node_vec.push_back(temp); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| bool UnusedOpRemovePass::IsExceptions(const NodePtr &node) { | |||
| GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); | |||
| auto op_def = node->GetOpDesc(); | |||
| GE_CHK_BOOL_EXEC(op_def != nullptr, return false, "opdesc is nullptr"); | |||
| // permute optimised in permute_pass.cpp | |||
| if (op_def->GetType() == PERMUTE) { | |||
| GE_IF_BOOL_EXEC( | |||
| (node->GetInDataNodes().size() != 0 && | |||
| (node->GetInDataNodes().at(0) != nullptr && node->GetInDataNodes().at(0)->GetOpDesc() != nullptr && | |||
| node->GetInDataNodes().at(0)->GetOpDesc()->GetType() == ATTENTIONDECODER)), | |||
| return false); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace ge | |||
| @@ -1,41 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||
| #define GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "framework/common/ge_types.h" | |||
| #include "inc/graph_pass.h" | |||
| namespace ge { | |||
| class UnusedOpRemovePass : public GraphPass { | |||
| public: | |||
| explicit UnusedOpRemovePass(FrameworkType type) : fmktype_(type) {} | |||
| ~UnusedOpRemovePass() {} | |||
| Status Run(ge::ComputeGraphPtr graph) override; | |||
| bool IsExceptions(const ge::NodePtr &node); | |||
| private: | |||
| Status CollectParentNode(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node, | |||
| std::vector<ge::NodePtr> &node_vec); | |||
| std::vector<std::string> v_remove_ops; | |||
| FrameworkType fmktype_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||
| @@ -1,119 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/variable_format_pass.h" | |||
| #include <map> | |||
| #include <set> | |||
| #include <string> | |||
| #include "framework/common/debug/ge_log.h" | |||
| namespace ge { | |||
| Status VariableFormatPass::Run(ge::ComputeGraphPtr graph) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| for (auto &node : graph->GetDirectNode()) { | |||
| GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | |||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); | |||
| ge::NodePtr use_node = nullptr; | |||
| if (GetApplyMomentumOpByVariableInput(node, use_node)) { | |||
| GE_CHK_STATUS_RET(UpdateVariableOutFormat(node, use_node), "update variable out format failed"); | |||
| GE_CHK_STATUS_RET(UpdateApplyMomentumInputFormat(use_node), "update apply momentum input format failed"); | |||
| } | |||
| } | |||
| return domi::SUCCESS; | |||
| } | |||
| bool VariableFormatPass::GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node) { | |||
| GE_IF_BOOL_EXEC(var_node == nullptr, return false); | |||
| std::map<std::string, std::set<int>> confirm_ops = {{"ApplyMomentum", {1}}}; | |||
| for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { | |||
| for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
| GE_IF_BOOL_EXEC(ConfirmUseOpAndIndexByAnchor(in_anchor, confirm_ops, use_node), return true); | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool VariableFormatPass::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, | |||
| const map<string, std::set<int>> &confirm_ops, | |||
| ge::NodePtr &use_node) { | |||
| GE_IF_BOOL_EXEC(in_anchor == nullptr, return false); | |||
| ge::NodePtr dst_node = in_anchor->GetOwnerNode(); | |||
| ge::OpDescPtr dst_op_desc = dst_node->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(dst_op_desc == nullptr, return false); | |||
| const string &dst_type = dst_op_desc->GetType(); | |||
| int input_index = in_anchor->GetIdx(); | |||
| GELOGD("ConfirmUseOpAndIndex, var name %s, dst_type = %s, input index %d", dst_node->GetName().c_str(), | |||
| dst_type.c_str(), input_index); | |||
| GE_IF_BOOL_EXEC(confirm_ops.count(dst_type) > 0, | |||
| GE_IF_BOOL_EXEC(confirm_ops.at(dst_type).count(input_index) > 0, use_node = dst_node; return true);); | |||
| return false; | |||
| } | |||
| Status VariableFormatPass::UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node) { | |||
| GE_CHECK_NOTNULL(var_node); | |||
| GE_CHECK_NOTNULL(use_node); | |||
| ge::OpDescPtr op_desc_ptr = use_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc_ptr); | |||
| GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)); | |||
| GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)->GetPeerOutAnchor()); | |||
| NodePtr in_node = use_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
| if (in_node != nullptr) { | |||
| string in_op_type = in_node->GetType(); | |||
| if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr) && | |||
| (in_node->GetOpDesc()->MutableOutputDesc(0) != nullptr)) { | |||
| ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); | |||
| ge::OpDescPtr cur_op_desc_ptr = var_node->GetOpDesc(); | |||
| if (cur_op_desc_ptr != nullptr) { | |||
| cur_op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); | |||
| cur_op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); | |||
| } | |||
| } | |||
| } | |||
| return domi::SUCCESS; | |||
| } | |||
| Status VariableFormatPass::UpdateApplyMomentumInputFormat(const ge::NodePtr &node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| ge::OpDescPtr op_desc_ptr = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc_ptr); | |||
| GE_CHECK_NOTNULL(node->GetInDataAnchor(0)); | |||
| GE_CHECK_NOTNULL(node->GetInDataAnchor(0)->GetPeerOutAnchor()); | |||
| GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(0)); | |||
| GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(1)); | |||
| GE_CHECK_NOTNULL(op_desc_ptr->MutableOutputDesc(0)); | |||
| NodePtr in_node = node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
| if (in_node != nullptr) { | |||
| string in_op_type = in_node->GetType(); | |||
| if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr)) { | |||
| ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); | |||
| op_desc_ptr->MutableInputDesc(0)->SetFormat(format); | |||
| op_desc_ptr->MutableInputDesc(0)->SetOriginFormat(format); | |||
| op_desc_ptr->MutableInputDesc(1)->SetFormat(format); | |||
| op_desc_ptr->MutableInputDesc(1)->SetOriginFormat(format); | |||
| op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); | |||
| op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); | |||
| } | |||
| } | |||
| return domi::SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -1,44 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||
| #define GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||
| #include <map> | |||
| #include <set> | |||
| #include <string> | |||
| #include "graph/types.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| #include "inc/graph_pass.h" | |||
| namespace ge { | |||
| class VariableFormatPass : public GraphPass { | |||
| public: | |||
| Status Run(ge::ComputeGraphPtr graph) override; | |||
| private: | |||
| bool GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node); | |||
| bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, | |||
| const map<string, std::set<int> > &confirm_ops, ge::NodePtr &use_node); | |||
| Status UpdateApplyMomentumInputFormat(const ge::NodePtr &node); | |||
| Status UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||
| @@ -216,12 +216,10 @@ set(COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc" | |||
| @@ -263,7 +261,6 @@ set(COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/variable_format_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | |||
| @@ -495,8 +492,6 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc" | |||
| @@ -670,8 +665,7 @@ set(PASS_TEST_FILES | |||
| "graph/passes/permute_pass_unittest.cc" | |||
| "graph/passes/print_op_pass_unittest.cc" | |||
| "graph/passes/shape_operate_op_remove_pass_unittest.cc" | |||
| "graph/passes/unused_and_isolated_op_remove_pass_unittest.cc" | |||
| "graph/passes/variable_op_pass_unittest.cc" | |||
| "graph/passes/variable_op_pass_unittest.cc" | |||
| "graph/passes/base_pass_unittest.cc" | |||
| "graph/passes/addn_pass_unittest.cc" | |||
| "graph/passes/save_pass_unittest.cc" | |||
| @@ -1,191 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/unused_op_remove_pass.h" | |||
| #include <gtest/gtest.h> | |||
| #include "graph/passes/isolated_op_remove_pass.h" | |||
| #include "pass_manager.h" | |||
| using namespace ge; | |||
| class UtestGraphPassesUnusedAndIsolatedOpRemovePass : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| NodePtr AddNode(ComputeGraphPtr graph, const string &name, const string &type, int32_t in_anchors_num = 1, | |||
| int32_t out_anchors_num = 1) { | |||
| GeTensorDesc tensor_desc; | |||
| OpDescPtr op_desc = make_shared<OpDesc>(name, type); | |||
| for (int32_t i = 0; i < in_anchors_num; i++) { | |||
| op_desc->AddInputDesc(tensor_desc); | |||
| } | |||
| for (int32_t i = 0; i < out_anchors_num; i++) { | |||
| op_desc->AddOutputDesc(tensor_desc); | |||
| } | |||
| NodePtr node = graph->AddNode(op_desc); | |||
| return node; | |||
| } | |||
| }; | |||
| TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_reshape) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
| NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||
| NodePtr reshape_node = AddNode(graph, "reshape1", RESHAPE); | |||
| GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), reshape_node->GetInDataAnchor(0)); | |||
| ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
| ge::IsolatedOpRemovePass isolate_pass; | |||
| std::vector<std::pair<string, GraphPass*>> passes; | |||
| passes.emplace_back("", &isolate_pass); | |||
| passes.emplace_back("", &unused_pass); | |||
| Status status = PassManager::Run(graph, passes); | |||
| EXPECT_EQ(SUCCESS, status); | |||
| NodePtr found_node = graph->FindNode("transpose1"); | |||
| EXPECT_EQ(transpose_node, found_node); | |||
| } | |||
| TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_squeeze) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
| NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||
| NodePtr squeeze_node = AddNode(graph, "squeeze1", SQUEEZE); | |||
| GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), squeeze_node->GetInDataAnchor(0)); | |||
| ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
| ge::IsolatedOpRemovePass isolate_pass; | |||
| std::vector<std::pair<string, GraphPass*>> passes; | |||
| passes.emplace_back("", &isolate_pass); | |||
| passes.emplace_back("", &unused_pass); | |||
| Status status = PassManager::Run(graph, passes); | |||
| EXPECT_EQ(SUCCESS, status); | |||
| NodePtr found_node = graph->FindNode("transpose1"); | |||
| EXPECT_EQ(transpose_node, found_node); | |||
| } | |||
| TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
| NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||
| vector<int64_t> order_list = {0, 2, 3, 1}; | |||
| AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list); | |||
| AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT); | |||
| NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION); | |||
| GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); | |||
| NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | |||
| GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | |||
| ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
| ge::IsolatedOpRemovePass isolate_pass; | |||
| std::vector<std::pair<string, GraphPass*>> passes; | |||
| passes.emplace_back("", &isolate_pass); | |||
| passes.emplace_back("", &unused_pass); | |||
| Status status = PassManager::Run(graph, passes); | |||
| EXPECT_EQ(SUCCESS, status); | |||
| NodePtr found_node0 = graph->FindNode("transpose1"); | |||
| NodePtr found_node = graph->FindNode("conv1"); | |||
| EXPECT_EQ(conv_node, found_node); | |||
| } | |||
| TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv3) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
| NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||
| vector<int64_t> order_list = {0, 1, 3, 2}; | |||
| AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list); | |||
| AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT); | |||
| NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION); | |||
| GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); | |||
| NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | |||
| GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | |||
| ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
| ge::IsolatedOpRemovePass isolate_pass; | |||
| std::vector<std::pair<string, GraphPass*>> passes; | |||
| passes.emplace_back("", &isolate_pass); | |||
| passes.emplace_back("", &unused_pass); | |||
| Status status = PassManager::Run(graph, passes); | |||
| EXPECT_EQ(SUCCESS, status); | |||
| NodePtr found_node0 = graph->FindNode("transpose1"); | |||
| EXPECT_EQ(transpose_node, found_node0); | |||
| NodePtr found_node = graph->FindNode("conv1"); | |||
| EXPECT_EQ(conv_node, found_node); | |||
| } | |||
| TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, cast_and_cast) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
| NodePtr conv3_node = AddNode(graph, "cast3", CAST); | |||
| NodePtr transpose_node = AddNode(graph, "cast1", CAST); | |||
| NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST); | |||
| GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | |||
| ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
| ge::IsolatedOpRemovePass isolate_pass; | |||
| std::vector<std::pair<string, GraphPass*>> passes; | |||
| passes.emplace_back("", &isolate_pass); | |||
| passes.emplace_back("", &unused_pass); | |||
| Status status = PassManager::Run(graph, passes); | |||
| EXPECT_EQ(SUCCESS, status); | |||
| } | |||
| TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, remove_parent_node) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| vector<NodePtr> node_vec; | |||
| NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
| NodePtr conv3_node = AddNode(graph, "cast3", CAST); | |||
| NodePtr transpose_node = AddNode(graph, "cast1", CAST); | |||
| NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST); | |||
| GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | |||
| ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
| ge::IsolatedOpRemovePass isolate_pass; | |||
| std::vector<std::pair<string, GraphPass*>> passes; | |||
| passes.emplace_back("", &isolate_pass); | |||
| passes.emplace_back("", &unused_pass); | |||
| Status status = PassManager::Run(graph, passes); | |||
| EXPECT_EQ(SUCCESS, status); | |||
| } | |||