From: @hugo1 Reviewed-by: @sheng-nan,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
@@ -270,7 +270,6 @@ set(TRAIN_SRC_LIST | |||
"graph/passes/identity_pass.cc" | |||
"graph/passes/ref_identity_delete_op_pass.cc" | |||
"graph/passes/infershape_pass.cc" | |||
"graph/passes/isolated_op_remove_pass.cc" | |||
"graph/passes/iterator_op_pass.cc" | |||
"graph/passes/link_gen_mask_nodes_pass.cc" | |||
"graph/passes/merge_pass.cc" | |||
@@ -317,13 +316,11 @@ set(TRAIN_SRC_LIST | |||
"graph/passes/transop_without_reshape_fusion_pass.cc" | |||
"graph/passes/transpose_transdata_pass.cc" | |||
"graph/passes/unused_const_pass.cc" | |||
"graph/passes/unused_op_remove_pass.cc" | |||
"graph/passes/var_is_initialized_op_pass.cc" | |||
"graph/passes/parallel_concat_start_op_pass.cc" | |||
"graph/passes/cond_pass.cc" | |||
"graph/passes/cond_remove_pass.cc" | |||
"graph/passes/for_pass.cc" | |||
"graph/passes/variable_format_pass.cc" | |||
"graph/passes/variable_op_pass.cc" | |||
"graph/passes/variable_prepare_op_pass.cc" | |||
"graph/passes/variable_ref_delete_op_pass.cc" | |||
@@ -522,12 +519,10 @@ set(INFER_SRC_LIST | |||
"graph/passes/dimension_adjust_pass.cc" | |||
"graph/passes/get_original_format_pass.cc" | |||
"graph/passes/shape_operate_op_remove_pass.cc" | |||
"graph/passes/unused_op_remove_pass.cc" | |||
"graph/passes/assert_pass.cc" | |||
"graph/passes/dropout_pass.cc" | |||
"graph/passes/infershape_pass.cc" | |||
"graph/passes/unused_const_pass.cc" | |||
"graph/passes/isolated_op_remove_pass.cc" | |||
"graph/passes/permute_pass.cc" | |||
"graph/passes/ctrl_edge_transfer_pass.cc" | |||
"graph/passes/end_of_sequence_add_control_pass.cc" | |||
@@ -610,7 +605,6 @@ set(INFER_SRC_LIST | |||
"graph/passes/switch_logic_remove_pass.cc" | |||
"graph/passes/switch_data_edges_bypass.cc" | |||
"graph/passes/merge_pass.cc" | |||
"graph/passes/variable_format_pass.cc" | |||
"graph/passes/variable_op_pass.cc" | |||
"graph/passes/cast_remove_pass.cc" | |||
"graph/passes/transpose_transdata_pass.cc" | |||
@@ -122,12 +122,10 @@ OMG_HOST_SRC_FILES := \ | |||
graph/passes/dimension_adjust_pass.cc \ | |||
graph/passes/get_original_format_pass.cc \ | |||
graph/passes/shape_operate_op_remove_pass.cc \ | |||
graph/passes/unused_op_remove_pass.cc \ | |||
graph/passes/assert_pass.cc \ | |||
graph/passes/dropout_pass.cc \ | |||
graph/passes/infershape_pass.cc \ | |||
graph/passes/unused_const_pass.cc \ | |||
graph/passes/isolated_op_remove_pass.cc \ | |||
graph/passes/permute_pass.cc \ | |||
graph/passes/ctrl_edge_transfer_pass.cc \ | |||
graph/passes/end_of_sequence_add_control_pass.cc \ | |||
@@ -209,7 +207,6 @@ OMG_HOST_SRC_FILES := \ | |||
graph/passes/switch_logic_remove_pass.cc \ | |||
graph/passes/switch_data_edges_bypass.cc \ | |||
graph/passes/merge_pass.cc \ | |||
graph/passes/variable_format_pass.cc \ | |||
graph/passes/variable_op_pass.cc \ | |||
graph/passes/cast_remove_pass.cc \ | |||
graph/passes/transpose_transdata_pass.cc \ | |||
@@ -187,7 +187,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
graph/passes/identity_pass.cc \ | |||
graph/passes/ref_identity_delete_op_pass.cc \ | |||
graph/passes/infershape_pass.cc \ | |||
graph/passes/isolated_op_remove_pass.cc \ | |||
graph/passes/iterator_op_pass.cc \ | |||
graph/passes/link_gen_mask_nodes_pass.cc \ | |||
graph/passes/merge_pass.cc \ | |||
@@ -233,13 +232,11 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
graph/passes/transop_without_reshape_fusion_pass.cc \ | |||
graph/passes/transpose_transdata_pass.cc \ | |||
graph/passes/unused_const_pass.cc \ | |||
graph/passes/unused_op_remove_pass.cc \ | |||
graph/passes/var_is_initialized_op_pass.cc \ | |||
graph/passes/parallel_concat_start_op_pass.cc \ | |||
graph/passes/cond_pass.cc \ | |||
graph/passes/cond_remove_pass.cc \ | |||
graph/passes/for_pass.cc \ | |||
graph/passes/variable_format_pass.cc \ | |||
graph/passes/variable_op_pass.cc \ | |||
graph/passes/variable_prepare_op_pass.cc \ | |||
graph/passes/variable_ref_delete_op_pass.cc \ | |||
@@ -1,37 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/isolated_op_remove_pass.h" | |||
#include "common/debug/log.h" | |||
#include "common/types.h" | |||
#include "common/util.h" | |||
namespace ge { | |||
Status IsolatedOpRemovePass::Run(ge::ComputeGraphPtr graph) { | |||
GE_CHECK_NOTNULL(graph); | |||
for (NodePtr &node_ptr : graph->GetDirectNode()) { | |||
GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, continue); | |||
if (node_ptr->GetInDataNodes().size() == 0 && node_ptr->GetOutAllNodes().size() == 0 && | |||
!(node_ptr->GetOpDesc()->HasAttr(TO_BE_OUTPUT))) { | |||
GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node_ptr), "remove graph node [%s] fail", | |||
node_ptr->GetOpDesc()->GetName().c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -1,28 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||
#define GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||
#include "inc/graph_pass.h" | |||
namespace ge { | |||
class IsolatedOpRemovePass : public GraphPass { | |||
public: | |||
Status Run(ge::ComputeGraphPtr graph); | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ |
@@ -1,47 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "remove_nodes_pass.h" | |||
#include "debug/ge_log.h" | |||
#include "inc/framework/common/util.h" | |||
#include "inc/graph/utils/node_utils.h" | |||
namespace ge { | |||
Status RemoveNodesPass::Run(NodePtr &node) { | |||
GE_CHECK_NOTNULL(node); | |||
auto node_type = NodeUtils::GetNodeType(*node); | |||
auto type_iter = remove_node_types_to_arg_.find(node_type); | |||
if (type_iter != remove_node_types_to_arg_.end()) { | |||
GELOGI("Remove node %s by type %s", node->GetName().c_str(), node_type.c_str()); | |||
return IsolateAndDeleteNode(node, type_iter->second); | |||
} | |||
for (const auto &attr_name_to_arg : remove_node_attr_names_to_arg_) { | |||
if (AttrUtils::HasAttr(node->GetOpDesc(), attr_name_to_arg.first)) { | |||
GELOGI("Remove node %s by attr name %s", node->GetName().c_str(), attr_name_to_arg.first.c_str()); | |||
return IsolateAndDeleteNode(node, attr_name_to_arg.second); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
RemoveNodesPass &RemoveNodesPass::AddNodeType(const string &node_type, std::initializer_list<int> arg) { | |||
remove_node_types_to_arg_[node_type] = std::move(arg); | |||
return *this; | |||
} | |||
RemoveNodesPass &RemoveNodesPass::AddAttrName(const string &attr_name, std::initializer_list<int> arg) { | |||
remove_node_attr_names_to_arg_[attr_name] = std::move(arg); | |||
return *this; | |||
} | |||
} // namespace ge |
@@ -1,32 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_REMOVE_NODES_PASS_H_ | |||
#define GE_REMOVE_NODES_PASS_H_ | |||
#include "graph/passes/base_pass.h" | |||
namespace ge { | |||
class RemoveNodesPass : public BaseNodePass { | |||
public: | |||
Status Run(NodePtr &node) override; | |||
RemoveNodesPass &AddNodeType(const std::string &node_type, std::initializer_list<int> arg = {0}); | |||
RemoveNodesPass &AddAttrName(const std::string &attr_name, std::initializer_list<int> arg = {0}); | |||
private: | |||
std::map<std::string, std::initializer_list<int>> remove_node_types_to_arg_; | |||
std::map<std::string, std::initializer_list<int>> remove_node_attr_names_to_arg_; | |||
}; | |||
} // namespace ge | |||
#endif //GE_REMOVE_NODES_PASS_H_ |
@@ -1,134 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/unused_op_remove_pass.h" | |||
#include <queue> | |||
#include <set> | |||
#include <string> | |||
#include <vector> | |||
#include "common/debug/log.h" | |||
#include "common/op/ge_op_utils.h" | |||
#include "common/types.h" | |||
#include "common/util.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "inc/pass_manager.h" | |||
#include "graph/passes/isolated_op_remove_pass.h" | |||
using domi::SUCCESS; | |||
namespace ge { | |||
const std::set<std::string> kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT}; | |||
const std::set<std::string> kOtherRemoveOpSet = {DROPOUT}; | |||
Status UnusedOpRemovePass::Run(ComputeGraphPtr graph) { | |||
GE_CHECK_NOTNULL(graph); | |||
std::set<std::string> remove_op_set; | |||
vector<NodePtr> nodes_to_be_deleted; | |||
if (fmktype_ == TENSORFLOW) { | |||
remove_op_set = kRemoveOpSet; | |||
} else { | |||
remove_op_set = kOtherRemoveOpSet; | |||
} | |||
for (auto &node : graph->GetDirectNode()) { | |||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
std::string op_type_str = node->GetOpDesc()->GetType(); | |||
if (remove_op_set.count(op_type_str)) { | |||
if (IsExceptions(node)) { | |||
continue; | |||
} | |||
for (auto &out_anchor : node->GetAllOutDataAnchors()) { | |||
for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
NodePtr dst_node = in_anchor->GetOwnerNode(); | |||
GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||
int dst_index = in_anchor->GetIdx(); | |||
std::vector<bool> list_bool; | |||
GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||
list_bool = dst_node->GetOpDesc()->GetIsInputConst(); | |||
GE_IF_BOOL_EXEC(list_bool.size() == 0, continue); | |||
list_bool.erase(list_bool.begin() + dst_index); | |||
dst_node->GetOpDesc()->SetIsInputConst(list_bool); | |||
} | |||
} | |||
if (op_type_str == ASSERT) { | |||
GE_CHK_STATUS_RET(CollectParentNode(graph, node, nodes_to_be_deleted), "remove node failed"); | |||
} else { | |||
GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node failed"); | |||
} | |||
} | |||
} | |||
for (auto &node : nodes_to_be_deleted) { | |||
for (InDataAnchorPtr &inAnchor : node->GetAllInDataAnchors()) { | |||
inAnchor->UnlinkAll(); | |||
} | |||
for (OutDataAnchorPtr &outAnchorPtr : node->GetAllOutDataAnchors()) { | |||
outAnchorPtr->UnlinkAll(); | |||
} | |||
if (node->GetOutControlAnchor() != nullptr) { | |||
node->GetOutControlAnchor()->UnlinkAll(); | |||
} | |||
GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node:%s failed", node->GetName().c_str()); | |||
} | |||
return SUCCESS; | |||
} | |||
Status UnusedOpRemovePass::CollectParentNode(const ComputeGraphPtr &graph, const NodePtr &node, | |||
vector<NodePtr> &node_vec) { | |||
GE_CHECK_NOTNULL(graph); | |||
GE_CHECK_NOTNULL(node); | |||
node_vec.push_back(node); | |||
std::queue<NodePtr> node_queue; | |||
for (auto &src_node : node->GetInDataNodes()) { | |||
if (src_node->GetOutDataNodesSize() == 1) { | |||
node_queue.push(src_node); | |||
} | |||
} | |||
while (!node_queue.empty()) { | |||
NodePtr temp = node_queue.front(); | |||
node_queue.pop(); | |||
for (auto &src_node : temp->GetInDataNodes()) { | |||
if (src_node->GetOutDataNodesSize() == 1) { | |||
node_queue.push(src_node); | |||
} | |||
} | |||
node_vec.push_back(temp); | |||
} | |||
return SUCCESS; | |||
} | |||
bool UnusedOpRemovePass::IsExceptions(const NodePtr &node) { | |||
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); | |||
auto op_def = node->GetOpDesc(); | |||
GE_CHK_BOOL_EXEC(op_def != nullptr, return false, "opdesc is nullptr"); | |||
// permute optimised in permute_pass.cpp | |||
if (op_def->GetType() == PERMUTE) { | |||
GE_IF_BOOL_EXEC( | |||
(node->GetInDataNodes().size() != 0 && | |||
(node->GetInDataNodes().at(0) != nullptr && node->GetInDataNodes().at(0)->GetOpDesc() != nullptr && | |||
node->GetInDataNodes().at(0)->GetOpDesc()->GetType() == ATTENTIONDECODER)), | |||
return false); | |||
return true; | |||
} | |||
return false; | |||
} | |||
} // namespace ge |
@@ -1,41 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||
#define GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||
#include <string> | |||
#include <vector> | |||
#include "framework/common/ge_types.h" | |||
#include "inc/graph_pass.h" | |||
namespace ge { | |||
class UnusedOpRemovePass : public GraphPass { | |||
public: | |||
explicit UnusedOpRemovePass(FrameworkType type) : fmktype_(type) {} | |||
~UnusedOpRemovePass() {} | |||
Status Run(ge::ComputeGraphPtr graph) override; | |||
bool IsExceptions(const ge::NodePtr &node); | |||
private: | |||
Status CollectParentNode(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node, | |||
std::vector<ge::NodePtr> &node_vec); | |||
std::vector<std::string> v_remove_ops; | |||
FrameworkType fmktype_; | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ |
@@ -1,119 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/variable_format_pass.h" | |||
#include <map> | |||
#include <set> | |||
#include <string> | |||
#include "framework/common/debug/ge_log.h" | |||
namespace ge { | |||
Status VariableFormatPass::Run(ge::ComputeGraphPtr graph) { | |||
GE_CHECK_NOTNULL(graph); | |||
for (auto &node : graph->GetDirectNode()) { | |||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | |||
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); | |||
ge::NodePtr use_node = nullptr; | |||
if (GetApplyMomentumOpByVariableInput(node, use_node)) { | |||
GE_CHK_STATUS_RET(UpdateVariableOutFormat(node, use_node), "update variable out format failed"); | |||
GE_CHK_STATUS_RET(UpdateApplyMomentumInputFormat(use_node), "update apply momentum input format failed"); | |||
} | |||
} | |||
return domi::SUCCESS; | |||
} | |||
bool VariableFormatPass::GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node) { | |||
GE_IF_BOOL_EXEC(var_node == nullptr, return false); | |||
std::map<std::string, std::set<int>> confirm_ops = {{"ApplyMomentum", {1}}}; | |||
for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { | |||
for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
GE_IF_BOOL_EXEC(ConfirmUseOpAndIndexByAnchor(in_anchor, confirm_ops, use_node), return true); | |||
} | |||
} | |||
return false; | |||
} | |||
bool VariableFormatPass::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, | |||
const map<string, std::set<int>> &confirm_ops, | |||
ge::NodePtr &use_node) { | |||
GE_IF_BOOL_EXEC(in_anchor == nullptr, return false); | |||
ge::NodePtr dst_node = in_anchor->GetOwnerNode(); | |||
ge::OpDescPtr dst_op_desc = dst_node->GetOpDesc(); | |||
GE_IF_BOOL_EXEC(dst_op_desc == nullptr, return false); | |||
const string &dst_type = dst_op_desc->GetType(); | |||
int input_index = in_anchor->GetIdx(); | |||
GELOGD("ConfirmUseOpAndIndex, var name %s, dst_type = %s, input index %d", dst_node->GetName().c_str(), | |||
dst_type.c_str(), input_index); | |||
GE_IF_BOOL_EXEC(confirm_ops.count(dst_type) > 0, | |||
GE_IF_BOOL_EXEC(confirm_ops.at(dst_type).count(input_index) > 0, use_node = dst_node; return true);); | |||
return false; | |||
} | |||
Status VariableFormatPass::UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node) { | |||
GE_CHECK_NOTNULL(var_node); | |||
GE_CHECK_NOTNULL(use_node); | |||
ge::OpDescPtr op_desc_ptr = use_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc_ptr); | |||
GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)); | |||
GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)->GetPeerOutAnchor()); | |||
NodePtr in_node = use_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
if (in_node != nullptr) { | |||
string in_op_type = in_node->GetType(); | |||
if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr) && | |||
(in_node->GetOpDesc()->MutableOutputDesc(0) != nullptr)) { | |||
ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); | |||
ge::OpDescPtr cur_op_desc_ptr = var_node->GetOpDesc(); | |||
if (cur_op_desc_ptr != nullptr) { | |||
cur_op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); | |||
cur_op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); | |||
} | |||
} | |||
} | |||
return domi::SUCCESS; | |||
} | |||
Status VariableFormatPass::UpdateApplyMomentumInputFormat(const ge::NodePtr &node) { | |||
GE_CHECK_NOTNULL(node); | |||
ge::OpDescPtr op_desc_ptr = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc_ptr); | |||
GE_CHECK_NOTNULL(node->GetInDataAnchor(0)); | |||
GE_CHECK_NOTNULL(node->GetInDataAnchor(0)->GetPeerOutAnchor()); | |||
GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(0)); | |||
GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(1)); | |||
GE_CHECK_NOTNULL(op_desc_ptr->MutableOutputDesc(0)); | |||
NodePtr in_node = node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
if (in_node != nullptr) { | |||
string in_op_type = in_node->GetType(); | |||
if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr)) { | |||
ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); | |||
op_desc_ptr->MutableInputDesc(0)->SetFormat(format); | |||
op_desc_ptr->MutableInputDesc(0)->SetOriginFormat(format); | |||
op_desc_ptr->MutableInputDesc(1)->SetFormat(format); | |||
op_desc_ptr->MutableInputDesc(1)->SetOriginFormat(format); | |||
op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); | |||
op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); | |||
} | |||
} | |||
return domi::SUCCESS; | |||
} | |||
} // namespace ge |
@@ -1,44 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||
#define GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||
#include <map> | |||
#include <set> | |||
#include <string> | |||
#include "graph/types.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "inc/graph_pass.h" | |||
namespace ge { | |||
class VariableFormatPass : public GraphPass { | |||
public: | |||
Status Run(ge::ComputeGraphPtr graph) override; | |||
private: | |||
bool GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node); | |||
bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, | |||
const map<string, std::set<int> > &confirm_ops, ge::NodePtr &use_node); | |||
Status UpdateApplyMomentumInputFormat(const ge::NodePtr &node); | |||
Status UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node); | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ |
@@ -216,12 +216,10 @@ set(COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc" | |||
@@ -263,7 +261,6 @@ set(COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/variable_format_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | |||
@@ -495,8 +492,6 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc" | |||
@@ -670,8 +665,7 @@ set(PASS_TEST_FILES | |||
"graph/passes/permute_pass_unittest.cc" | |||
"graph/passes/print_op_pass_unittest.cc" | |||
"graph/passes/shape_operate_op_remove_pass_unittest.cc" | |||
"graph/passes/unused_and_isolated_op_remove_pass_unittest.cc" | |||
"graph/passes/variable_op_pass_unittest.cc" | |||
"graph/passes/variable_op_pass_unittest.cc" | |||
"graph/passes/base_pass_unittest.cc" | |||
"graph/passes/addn_pass_unittest.cc" | |||
"graph/passes/save_pass_unittest.cc" | |||
@@ -1,191 +0,0 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/unused_op_remove_pass.h" | |||
#include <gtest/gtest.h> | |||
#include "graph/passes/isolated_op_remove_pass.h" | |||
#include "pass_manager.h" | |||
using namespace ge; | |||
class UtestGraphPassesUnusedAndIsolatedOpRemovePass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
NodePtr AddNode(ComputeGraphPtr graph, const string &name, const string &type, int32_t in_anchors_num = 1, | |||
int32_t out_anchors_num = 1) { | |||
GeTensorDesc tensor_desc; | |||
OpDescPtr op_desc = make_shared<OpDesc>(name, type); | |||
for (int32_t i = 0; i < in_anchors_num; i++) { | |||
op_desc->AddInputDesc(tensor_desc); | |||
} | |||
for (int32_t i = 0; i < out_anchors_num; i++) { | |||
op_desc->AddOutputDesc(tensor_desc); | |||
} | |||
NodePtr node = graph->AddNode(op_desc); | |||
return node; | |||
} | |||
}; | |||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_reshape) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||
NodePtr reshape_node = AddNode(graph, "reshape1", RESHAPE); | |||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), reshape_node->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
NodePtr found_node = graph->FindNode("transpose1"); | |||
EXPECT_EQ(transpose_node, found_node); | |||
} | |||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_squeeze) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||
NodePtr squeeze_node = AddNode(graph, "squeeze1", SQUEEZE); | |||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), squeeze_node->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
NodePtr found_node = graph->FindNode("transpose1"); | |||
EXPECT_EQ(transpose_node, found_node); | |||
} | |||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||
vector<int64_t> order_list = {0, 2, 3, 1}; | |||
AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list); | |||
AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT); | |||
NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION); | |||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); | |||
NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | |||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
NodePtr found_node0 = graph->FindNode("transpose1"); | |||
NodePtr found_node = graph->FindNode("conv1"); | |||
EXPECT_EQ(conv_node, found_node); | |||
} | |||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv3) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE); | |||
vector<int64_t> order_list = {0, 1, 3, 2}; | |||
AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list); | |||
AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT); | |||
NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION); | |||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); | |||
NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | |||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
NodePtr found_node0 = graph->FindNode("transpose1"); | |||
EXPECT_EQ(transpose_node, found_node0); | |||
NodePtr found_node = graph->FindNode("conv1"); | |||
EXPECT_EQ(conv_node, found_node); | |||
} | |||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, cast_and_cast) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
NodePtr conv3_node = AddNode(graph, "cast3", CAST); | |||
NodePtr transpose_node = AddNode(graph, "cast1", CAST); | |||
NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST); | |||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
} | |||
TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, remove_parent_node) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
vector<NodePtr> node_vec; | |||
NodePtr data_node = AddNode(graph, "DATA", DATA); | |||
NodePtr conv3_node = AddNode(graph, "cast3", CAST); | |||
NodePtr transpose_node = AddNode(graph, "cast1", CAST); | |||
NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST); | |||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
} |