Browse Source

!1564 Bugfix: Delete abandoned optimize pass

From: @hugo1
Reviewed-by: @sheng-nan,@ji_chen
Signed-off-by: @ji_chen
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
b3db854eef
13 changed files with 1 additions and 692 deletions
  1. +0
    -6
      ge/CMakeLists.txt
  2. +0
    -3
      ge/ge_inference.mk
  3. +0
    -3
      ge/ge_runner.mk
  4. +0
    -37
      ge/graph/passes/isolated_op_remove_pass.cc
  5. +0
    -28
      ge/graph/passes/isolated_op_remove_pass.h
  6. +0
    -47
      ge/graph/passes/remove_nodes_pass.cc
  7. +0
    -32
      ge/graph/passes/remove_nodes_pass.h
  8. +0
    -134
      ge/graph/passes/unused_op_remove_pass.cc
  9. +0
    -41
      ge/graph/passes/unused_op_remove_pass.h
  10. +0
    -119
      ge/graph/passes/variable_format_pass.cc
  11. +0
    -44
      ge/graph/passes/variable_format_pass.h
  12. +1
    -7
      tests/ut/ge/CMakeLists.txt
  13. +0
    -191
      tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc

+ 0
- 6
ge/CMakeLists.txt View File

@@ -270,7 +270,6 @@ set(TRAIN_SRC_LIST
"graph/passes/identity_pass.cc" "graph/passes/identity_pass.cc"
"graph/passes/ref_identity_delete_op_pass.cc" "graph/passes/ref_identity_delete_op_pass.cc"
"graph/passes/infershape_pass.cc" "graph/passes/infershape_pass.cc"
"graph/passes/isolated_op_remove_pass.cc"
"graph/passes/iterator_op_pass.cc" "graph/passes/iterator_op_pass.cc"
"graph/passes/link_gen_mask_nodes_pass.cc" "graph/passes/link_gen_mask_nodes_pass.cc"
"graph/passes/merge_pass.cc" "graph/passes/merge_pass.cc"
@@ -317,13 +316,11 @@ set(TRAIN_SRC_LIST
"graph/passes/transop_without_reshape_fusion_pass.cc" "graph/passes/transop_without_reshape_fusion_pass.cc"
"graph/passes/transpose_transdata_pass.cc" "graph/passes/transpose_transdata_pass.cc"
"graph/passes/unused_const_pass.cc" "graph/passes/unused_const_pass.cc"
"graph/passes/unused_op_remove_pass.cc"
"graph/passes/var_is_initialized_op_pass.cc" "graph/passes/var_is_initialized_op_pass.cc"
"graph/passes/parallel_concat_start_op_pass.cc" "graph/passes/parallel_concat_start_op_pass.cc"
"graph/passes/cond_pass.cc" "graph/passes/cond_pass.cc"
"graph/passes/cond_remove_pass.cc" "graph/passes/cond_remove_pass.cc"
"graph/passes/for_pass.cc" "graph/passes/for_pass.cc"
"graph/passes/variable_format_pass.cc"
"graph/passes/variable_op_pass.cc" "graph/passes/variable_op_pass.cc"
"graph/passes/variable_prepare_op_pass.cc" "graph/passes/variable_prepare_op_pass.cc"
"graph/passes/variable_ref_delete_op_pass.cc" "graph/passes/variable_ref_delete_op_pass.cc"
@@ -522,12 +519,10 @@ set(INFER_SRC_LIST
"graph/passes/dimension_adjust_pass.cc" "graph/passes/dimension_adjust_pass.cc"
"graph/passes/get_original_format_pass.cc" "graph/passes/get_original_format_pass.cc"
"graph/passes/shape_operate_op_remove_pass.cc" "graph/passes/shape_operate_op_remove_pass.cc"
"graph/passes/unused_op_remove_pass.cc"
"graph/passes/assert_pass.cc" "graph/passes/assert_pass.cc"
"graph/passes/dropout_pass.cc" "graph/passes/dropout_pass.cc"
"graph/passes/infershape_pass.cc" "graph/passes/infershape_pass.cc"
"graph/passes/unused_const_pass.cc" "graph/passes/unused_const_pass.cc"
"graph/passes/isolated_op_remove_pass.cc"
"graph/passes/permute_pass.cc" "graph/passes/permute_pass.cc"
"graph/passes/ctrl_edge_transfer_pass.cc" "graph/passes/ctrl_edge_transfer_pass.cc"
"graph/passes/end_of_sequence_add_control_pass.cc" "graph/passes/end_of_sequence_add_control_pass.cc"
@@ -610,7 +605,6 @@ set(INFER_SRC_LIST
"graph/passes/switch_logic_remove_pass.cc" "graph/passes/switch_logic_remove_pass.cc"
"graph/passes/switch_data_edges_bypass.cc" "graph/passes/switch_data_edges_bypass.cc"
"graph/passes/merge_pass.cc" "graph/passes/merge_pass.cc"
"graph/passes/variable_format_pass.cc"
"graph/passes/variable_op_pass.cc" "graph/passes/variable_op_pass.cc"
"graph/passes/cast_remove_pass.cc" "graph/passes/cast_remove_pass.cc"
"graph/passes/transpose_transdata_pass.cc" "graph/passes/transpose_transdata_pass.cc"


+ 0
- 3
ge/ge_inference.mk View File

@@ -122,12 +122,10 @@ OMG_HOST_SRC_FILES := \
graph/passes/dimension_adjust_pass.cc \ graph/passes/dimension_adjust_pass.cc \
graph/passes/get_original_format_pass.cc \ graph/passes/get_original_format_pass.cc \
graph/passes/shape_operate_op_remove_pass.cc \ graph/passes/shape_operate_op_remove_pass.cc \
graph/passes/unused_op_remove_pass.cc \
graph/passes/assert_pass.cc \ graph/passes/assert_pass.cc \
graph/passes/dropout_pass.cc \ graph/passes/dropout_pass.cc \
graph/passes/infershape_pass.cc \ graph/passes/infershape_pass.cc \
graph/passes/unused_const_pass.cc \ graph/passes/unused_const_pass.cc \
graph/passes/isolated_op_remove_pass.cc \
graph/passes/permute_pass.cc \ graph/passes/permute_pass.cc \
graph/passes/ctrl_edge_transfer_pass.cc \ graph/passes/ctrl_edge_transfer_pass.cc \
graph/passes/end_of_sequence_add_control_pass.cc \ graph/passes/end_of_sequence_add_control_pass.cc \
@@ -209,7 +207,6 @@ OMG_HOST_SRC_FILES := \
graph/passes/switch_logic_remove_pass.cc \ graph/passes/switch_logic_remove_pass.cc \
graph/passes/switch_data_edges_bypass.cc \ graph/passes/switch_data_edges_bypass.cc \
graph/passes/merge_pass.cc \ graph/passes/merge_pass.cc \
graph/passes/variable_format_pass.cc \
graph/passes/variable_op_pass.cc \ graph/passes/variable_op_pass.cc \
graph/passes/cast_remove_pass.cc \ graph/passes/cast_remove_pass.cc \
graph/passes/transpose_transdata_pass.cc \ graph/passes/transpose_transdata_pass.cc \


+ 0
- 3
ge/ge_runner.mk View File

@@ -187,7 +187,6 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/identity_pass.cc \ graph/passes/identity_pass.cc \
graph/passes/ref_identity_delete_op_pass.cc \ graph/passes/ref_identity_delete_op_pass.cc \
graph/passes/infershape_pass.cc \ graph/passes/infershape_pass.cc \
graph/passes/isolated_op_remove_pass.cc \
graph/passes/iterator_op_pass.cc \ graph/passes/iterator_op_pass.cc \
graph/passes/link_gen_mask_nodes_pass.cc \ graph/passes/link_gen_mask_nodes_pass.cc \
graph/passes/merge_pass.cc \ graph/passes/merge_pass.cc \
@@ -233,13 +232,11 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/transop_without_reshape_fusion_pass.cc \ graph/passes/transop_without_reshape_fusion_pass.cc \
graph/passes/transpose_transdata_pass.cc \ graph/passes/transpose_transdata_pass.cc \
graph/passes/unused_const_pass.cc \ graph/passes/unused_const_pass.cc \
graph/passes/unused_op_remove_pass.cc \
graph/passes/var_is_initialized_op_pass.cc \ graph/passes/var_is_initialized_op_pass.cc \
graph/passes/parallel_concat_start_op_pass.cc \ graph/passes/parallel_concat_start_op_pass.cc \
graph/passes/cond_pass.cc \ graph/passes/cond_pass.cc \
graph/passes/cond_remove_pass.cc \ graph/passes/cond_remove_pass.cc \
graph/passes/for_pass.cc \ graph/passes/for_pass.cc \
graph/passes/variable_format_pass.cc \
graph/passes/variable_op_pass.cc \ graph/passes/variable_op_pass.cc \
graph/passes/variable_prepare_op_pass.cc \ graph/passes/variable_prepare_op_pass.cc \
graph/passes/variable_ref_delete_op_pass.cc \ graph/passes/variable_ref_delete_op_pass.cc \


+ 0
- 37
ge/graph/passes/isolated_op_remove_pass.cc View File

@@ -1,37 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/isolated_op_remove_pass.h"

#include "common/debug/log.h"
#include "common/types.h"
#include "common/util.h"

namespace ge {
Status IsolatedOpRemovePass::Run(ge::ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph);
for (NodePtr &node_ptr : graph->GetDirectNode()) {
GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, continue);
if (node_ptr->GetInDataNodes().size() == 0 && node_ptr->GetOutAllNodes().size() == 0 &&
!(node_ptr->GetOpDesc()->HasAttr(TO_BE_OUTPUT))) {
GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node_ptr), "remove graph node [%s] fail",
node_ptr->GetOpDesc()->GetName().c_str());
}
}

return SUCCESS;
}
} // namespace ge

+ 0
- 28
ge/graph/passes/isolated_op_remove_pass.h View File

@@ -1,28 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_
#define GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_

#include "inc/graph_pass.h"

namespace ge {
class IsolatedOpRemovePass : public GraphPass {
public:
Status Run(ge::ComputeGraphPtr graph);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_

+ 0
- 47
ge/graph/passes/remove_nodes_pass.cc View File

@@ -1,47 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "remove_nodes_pass.h"
#include "debug/ge_log.h"
#include "inc/framework/common/util.h"
#include "inc/graph/utils/node_utils.h"

namespace ge {
Status RemoveNodesPass::Run(NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto node_type = NodeUtils::GetNodeType(*node);
auto type_iter = remove_node_types_to_arg_.find(node_type);
if (type_iter != remove_node_types_to_arg_.end()) {
GELOGI("Remove node %s by type %s", node->GetName().c_str(), node_type.c_str());
return IsolateAndDeleteNode(node, type_iter->second);
}
for (const auto &attr_name_to_arg : remove_node_attr_names_to_arg_) {
if (AttrUtils::HasAttr(node->GetOpDesc(), attr_name_to_arg.first)) {
GELOGI("Remove node %s by attr name %s", node->GetName().c_str(), attr_name_to_arg.first.c_str());
return IsolateAndDeleteNode(node, attr_name_to_arg.second);
}
}

return SUCCESS;
}
RemoveNodesPass &RemoveNodesPass::AddNodeType(const string &node_type, std::initializer_list<int> arg) {
remove_node_types_to_arg_[node_type] = std::move(arg);
return *this;
}
RemoveNodesPass &RemoveNodesPass::AddAttrName(const string &attr_name, std::initializer_list<int> arg) {
remove_node_attr_names_to_arg_[attr_name] = std::move(arg);
return *this;
}
} // namespace ge

+ 0
- 32
ge/graph/passes/remove_nodes_pass.h View File

@@ -1,32 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_REMOVE_NODES_PASS_H_
#define GE_REMOVE_NODES_PASS_H_
#include "graph/passes/base_pass.h"

namespace ge {
class RemoveNodesPass : public BaseNodePass {
public:
Status Run(NodePtr &node) override;
RemoveNodesPass &AddNodeType(const std::string &node_type, std::initializer_list<int> arg = {0});
RemoveNodesPass &AddAttrName(const std::string &attr_name, std::initializer_list<int> arg = {0});

private:
std::map<std::string, std::initializer_list<int>> remove_node_types_to_arg_;
std::map<std::string, std::initializer_list<int>> remove_node_attr_names_to_arg_;
};
} // namespace ge
#endif //GE_REMOVE_NODES_PASS_H_

+ 0
- 134
ge/graph/passes/unused_op_remove_pass.cc View File

@@ -1,134 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/unused_op_remove_pass.h"
#include <queue>
#include <set>
#include <string>
#include <vector>
#include "common/debug/log.h"
#include "common/op/ge_op_utils.h"
#include "common/types.h"
#include "common/util.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "inc/pass_manager.h"
#include "graph/passes/isolated_op_remove_pass.h"

using domi::SUCCESS;

namespace ge {
const std::set<std::string> kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT};
const std::set<std::string> kOtherRemoveOpSet = {DROPOUT};

Status UnusedOpRemovePass::Run(ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph);
std::set<std::string> remove_op_set;
vector<NodePtr> nodes_to_be_deleted;
if (fmktype_ == TENSORFLOW) {
remove_op_set = kRemoveOpSet;
} else {
remove_op_set = kOtherRemoveOpSet;
}

for (auto &node : graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node->GetOpDesc());
std::string op_type_str = node->GetOpDesc()->GetType();
if (remove_op_set.count(op_type_str)) {
if (IsExceptions(node)) {
continue;
}
for (auto &out_anchor : node->GetAllOutDataAnchors()) {
for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
NodePtr dst_node = in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node->GetOpDesc());
int dst_index = in_anchor->GetIdx();
std::vector<bool> list_bool;
GE_CHECK_NOTNULL(dst_node->GetOpDesc());
list_bool = dst_node->GetOpDesc()->GetIsInputConst();
GE_IF_BOOL_EXEC(list_bool.size() == 0, continue);
list_bool.erase(list_bool.begin() + dst_index);
dst_node->GetOpDesc()->SetIsInputConst(list_bool);
}
}
if (op_type_str == ASSERT) {
GE_CHK_STATUS_RET(CollectParentNode(graph, node, nodes_to_be_deleted), "remove node failed");
} else {
GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node failed");
}
}
}
for (auto &node : nodes_to_be_deleted) {
for (InDataAnchorPtr &inAnchor : node->GetAllInDataAnchors()) {
inAnchor->UnlinkAll();
}
for (OutDataAnchorPtr &outAnchorPtr : node->GetAllOutDataAnchors()) {
outAnchorPtr->UnlinkAll();
}
if (node->GetOutControlAnchor() != nullptr) {
node->GetOutControlAnchor()->UnlinkAll();
}
GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node:%s failed", node->GetName().c_str());
}

return SUCCESS;
}

Status UnusedOpRemovePass::CollectParentNode(const ComputeGraphPtr &graph, const NodePtr &node,
vector<NodePtr> &node_vec) {
GE_CHECK_NOTNULL(graph);
GE_CHECK_NOTNULL(node);
node_vec.push_back(node);
std::queue<NodePtr> node_queue;

for (auto &src_node : node->GetInDataNodes()) {
if (src_node->GetOutDataNodesSize() == 1) {
node_queue.push(src_node);
}
}

while (!node_queue.empty()) {
NodePtr temp = node_queue.front();
node_queue.pop();

for (auto &src_node : temp->GetInDataNodes()) {
if (src_node->GetOutDataNodesSize() == 1) {
node_queue.push(src_node);
}
}
node_vec.push_back(temp);
}

return SUCCESS;
}

bool UnusedOpRemovePass::IsExceptions(const NodePtr &node) {
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
auto op_def = node->GetOpDesc();
GE_CHK_BOOL_EXEC(op_def != nullptr, return false, "opdesc is nullptr");
// permute optimised in permute_pass.cpp
if (op_def->GetType() == PERMUTE) {
GE_IF_BOOL_EXEC(
(node->GetInDataNodes().size() != 0 &&
(node->GetInDataNodes().at(0) != nullptr && node->GetInDataNodes().at(0)->GetOpDesc() != nullptr &&
node->GetInDataNodes().at(0)->GetOpDesc()->GetType() == ATTENTIONDECODER)),
return false);
return true;
}
return false;
}
} // namespace ge

+ 0
- 41
ge/graph/passes/unused_op_remove_pass.h View File

@@ -1,41 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_
#define GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_

#include <string>
#include <vector>
#include "framework/common/ge_types.h"
#include "inc/graph_pass.h"

namespace ge {
class UnusedOpRemovePass : public GraphPass {
public:
explicit UnusedOpRemovePass(FrameworkType type) : fmktype_(type) {}
~UnusedOpRemovePass() {}
Status Run(ge::ComputeGraphPtr graph) override;
bool IsExceptions(const ge::NodePtr &node);

private:
Status CollectParentNode(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node,
std::vector<ge::NodePtr> &node_vec);
std::vector<std::string> v_remove_ops;
FrameworkType fmktype_;
};
} // namespace ge

#endif // GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_

+ 0
- 119
ge/graph/passes/variable_format_pass.cc View File

@@ -1,119 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/variable_format_pass.h"
#include <map>
#include <set>
#include <string>
#include "framework/common/debug/ge_log.h"

namespace ge {
Status VariableFormatPass::Run(ge::ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph);

for (auto &node : graph->GetDirectNode()) {
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue);

ge::NodePtr use_node = nullptr;
if (GetApplyMomentumOpByVariableInput(node, use_node)) {
GE_CHK_STATUS_RET(UpdateVariableOutFormat(node, use_node), "update variable out format failed");
GE_CHK_STATUS_RET(UpdateApplyMomentumInputFormat(use_node), "update apply momentum input format failed");
}
}

return domi::SUCCESS;
}

bool VariableFormatPass::GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node) {
GE_IF_BOOL_EXEC(var_node == nullptr, return false);

std::map<std::string, std::set<int>> confirm_ops = {{"ApplyMomentum", {1}}};
for (auto &out_anchor : var_node->GetAllOutDataAnchors()) {
for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
GE_IF_BOOL_EXEC(ConfirmUseOpAndIndexByAnchor(in_anchor, confirm_ops, use_node), return true);
}
}

return false;
}

bool VariableFormatPass::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor,
const map<string, std::set<int>> &confirm_ops,
ge::NodePtr &use_node) {
GE_IF_BOOL_EXEC(in_anchor == nullptr, return false);
ge::NodePtr dst_node = in_anchor->GetOwnerNode();
ge::OpDescPtr dst_op_desc = dst_node->GetOpDesc();
GE_IF_BOOL_EXEC(dst_op_desc == nullptr, return false);
const string &dst_type = dst_op_desc->GetType();
int input_index = in_anchor->GetIdx();

GELOGD("ConfirmUseOpAndIndex, var name %s, dst_type = %s, input index %d", dst_node->GetName().c_str(),
dst_type.c_str(), input_index);

GE_IF_BOOL_EXEC(confirm_ops.count(dst_type) > 0,
GE_IF_BOOL_EXEC(confirm_ops.at(dst_type).count(input_index) > 0, use_node = dst_node; return true););
return false;
}

Status VariableFormatPass::UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node) {
GE_CHECK_NOTNULL(var_node);
GE_CHECK_NOTNULL(use_node);
ge::OpDescPtr op_desc_ptr = use_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc_ptr);
GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0));
GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)->GetPeerOutAnchor());
NodePtr in_node = use_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode();
if (in_node != nullptr) {
string in_op_type = in_node->GetType();
if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr) &&
(in_node->GetOpDesc()->MutableOutputDesc(0) != nullptr)) {
ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat();
ge::OpDescPtr cur_op_desc_ptr = var_node->GetOpDesc();
if (cur_op_desc_ptr != nullptr) {
cur_op_desc_ptr->MutableOutputDesc(0)->SetFormat(format);
cur_op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format);
}
}
}
return domi::SUCCESS;
}

Status VariableFormatPass::UpdateApplyMomentumInputFormat(const ge::NodePtr &node) {
GE_CHECK_NOTNULL(node);
ge::OpDescPtr op_desc_ptr = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc_ptr);
GE_CHECK_NOTNULL(node->GetInDataAnchor(0));
GE_CHECK_NOTNULL(node->GetInDataAnchor(0)->GetPeerOutAnchor());
GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(0));
GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(1));
GE_CHECK_NOTNULL(op_desc_ptr->MutableOutputDesc(0));
NodePtr in_node = node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode();
if (in_node != nullptr) {
string in_op_type = in_node->GetType();
if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr)) {
ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat();
op_desc_ptr->MutableInputDesc(0)->SetFormat(format);
op_desc_ptr->MutableInputDesc(0)->SetOriginFormat(format);
op_desc_ptr->MutableInputDesc(1)->SetFormat(format);
op_desc_ptr->MutableInputDesc(1)->SetOriginFormat(format);
op_desc_ptr->MutableOutputDesc(0)->SetFormat(format);
op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format);
}
}
return domi::SUCCESS;
}
} // namespace ge

+ 0
- 44
ge/graph/passes/variable_format_pass.h View File

@@ -1,44 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_
#define GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_

#include <map>
#include <set>
#include <string>
#include "graph/types.h"
#include "graph/utils/op_desc_utils.h"
#include "inc/graph_pass.h"

namespace ge {
class VariableFormatPass : public GraphPass {
public:
Status Run(ge::ComputeGraphPtr graph) override;

private:
bool GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node);

bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor,
const map<string, std::set<int> > &confirm_ops, ge::NodePtr &use_node);

Status UpdateApplyMomentumInputFormat(const ge::NodePtr &node);

Status UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node);
};
} // namespace ge

#endif // GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_

+ 1
- 7
tests/ut/ge/CMakeLists.txt View File

@@ -216,12 +216,10 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc"
@@ -263,7 +261,6 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc"
"${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/variable_format_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc"
@@ -495,8 +492,6 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/unused_op_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/isolated_op_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc"
@@ -670,8 +665,7 @@ set(PASS_TEST_FILES
"graph/passes/permute_pass_unittest.cc" "graph/passes/permute_pass_unittest.cc"
"graph/passes/print_op_pass_unittest.cc" "graph/passes/print_op_pass_unittest.cc"
"graph/passes/shape_operate_op_remove_pass_unittest.cc" "graph/passes/shape_operate_op_remove_pass_unittest.cc"
"graph/passes/unused_and_isolated_op_remove_pass_unittest.cc"
"graph/passes/variable_op_pass_unittest.cc"
"graph/passes/variable_op_pass_unittest.cc"
"graph/passes/base_pass_unittest.cc" "graph/passes/base_pass_unittest.cc"
"graph/passes/addn_pass_unittest.cc" "graph/passes/addn_pass_unittest.cc"
"graph/passes/save_pass_unittest.cc" "graph/passes/save_pass_unittest.cc"


+ 0
- 191
tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc View File

@@ -1,191 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/unused_op_remove_pass.h"

#include <gtest/gtest.h>
#include "graph/passes/isolated_op_remove_pass.h"
#include "pass_manager.h"

using namespace ge;

class UtestGraphPassesUnusedAndIsolatedOpRemovePass : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}

NodePtr AddNode(ComputeGraphPtr graph, const string &name, const string &type, int32_t in_anchors_num = 1,
int32_t out_anchors_num = 1) {
GeTensorDesc tensor_desc;
OpDescPtr op_desc = make_shared<OpDesc>(name, type);
for (int32_t i = 0; i < in_anchors_num; i++) {
op_desc->AddInputDesc(tensor_desc);
}
for (int32_t i = 0; i < out_anchors_num; i++) {
op_desc->AddOutputDesc(tensor_desc);
}

NodePtr node = graph->AddNode(op_desc);
return node;
}
};

TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_reshape) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

NodePtr data_node = AddNode(graph, "DATA", DATA);
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE);
NodePtr reshape_node = AddNode(graph, "reshape1", RESHAPE);

GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0));
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), reshape_node->GetInDataAnchor(0));

ge::UnusedOpRemovePass unused_pass(TENSORFLOW);
ge::IsolatedOpRemovePass isolate_pass;
std::vector<std::pair<string, GraphPass*>> passes;
passes.emplace_back("", &isolate_pass);
passes.emplace_back("", &unused_pass);
Status status = PassManager::Run(graph, passes);
EXPECT_EQ(SUCCESS, status);
NodePtr found_node = graph->FindNode("transpose1");
EXPECT_EQ(transpose_node, found_node);
}

TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_squeeze) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

NodePtr data_node = AddNode(graph, "DATA", DATA);
NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE);
NodePtr squeeze_node = AddNode(graph, "squeeze1", SQUEEZE);

GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0));
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), squeeze_node->GetInDataAnchor(0));

ge::UnusedOpRemovePass unused_pass(TENSORFLOW);
ge::IsolatedOpRemovePass isolate_pass;
std::vector<std::pair<string, GraphPass*>> passes;
passes.emplace_back("", &isolate_pass);
passes.emplace_back("", &unused_pass);
Status status = PassManager::Run(graph, passes);
EXPECT_EQ(SUCCESS, status);
NodePtr found_node = graph->FindNode("transpose1");
EXPECT_EQ(transpose_node, found_node);
}

TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

NodePtr data_node = AddNode(graph, "DATA", DATA);

NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE);
vector<int64_t> order_list = {0, 2, 3, 1};
AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list);
AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT);

NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION);

GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0));
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0));

NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION);
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0));

ge::UnusedOpRemovePass unused_pass(TENSORFLOW);
ge::IsolatedOpRemovePass isolate_pass;
std::vector<std::pair<string, GraphPass*>> passes;
passes.emplace_back("", &isolate_pass);
passes.emplace_back("", &unused_pass);
Status status = PassManager::Run(graph, passes);
EXPECT_EQ(SUCCESS, status);
NodePtr found_node0 = graph->FindNode("transpose1");
NodePtr found_node = graph->FindNode("conv1");
EXPECT_EQ(conv_node, found_node);
}

TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv3) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

NodePtr data_node = AddNode(graph, "DATA", DATA);

NodePtr transpose_node = AddNode(graph, "transpose1", PERMUTE);
vector<int64_t> order_list = {0, 1, 3, 2};
AttrUtils::SetListInt(transpose_node->GetOpDesc(), PERMUTE_ATTR_ORDER, order_list);
AttrUtils::SetInt(transpose_node->GetOpDesc(), ATTR_NAME_FORMAT, (int64_t)DT_FLOAT);

NodePtr conv_node = AddNode(graph, "conv1", CONVOLUTION);

GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0));
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0));

NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION);
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0));

ge::UnusedOpRemovePass unused_pass(TENSORFLOW);
ge::IsolatedOpRemovePass isolate_pass;
std::vector<std::pair<string, GraphPass*>> passes;
passes.emplace_back("", &isolate_pass);
passes.emplace_back("", &unused_pass);
Status status = PassManager::Run(graph, passes);
EXPECT_EQ(SUCCESS, status);
NodePtr found_node0 = graph->FindNode("transpose1");
EXPECT_EQ(transpose_node, found_node0);
NodePtr found_node = graph->FindNode("conv1");
EXPECT_EQ(conv_node, found_node);
}

TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, cast_and_cast) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

NodePtr data_node = AddNode(graph, "DATA", DATA);
NodePtr conv3_node = AddNode(graph, "cast3", CAST);
NodePtr transpose_node = AddNode(graph, "cast1", CAST);
NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST);

GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0));
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0));
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0));

ge::UnusedOpRemovePass unused_pass(TENSORFLOW);
ge::IsolatedOpRemovePass isolate_pass;
std::vector<std::pair<string, GraphPass*>> passes;
passes.emplace_back("", &isolate_pass);
passes.emplace_back("", &unused_pass);
Status status = PassManager::Run(graph, passes);
EXPECT_EQ(SUCCESS, status);
}

TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, remove_parent_node) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
vector<NodePtr> node_vec;

NodePtr data_node = AddNode(graph, "DATA", DATA);
NodePtr conv3_node = AddNode(graph, "cast3", CAST);
NodePtr transpose_node = AddNode(graph, "cast1", CAST);
NodePtr transpose_node_1 = AddNode(graph, "cast2", CAST);

GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv3_node->GetInDataAnchor(0));
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0));
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0));

ge::UnusedOpRemovePass unused_pass(TENSORFLOW);
ge::IsolatedOpRemovePass isolate_pass;
std::vector<std::pair<string, GraphPass*>> passes;
passes.emplace_back("", &isolate_pass);
passes.emplace_back("", &unused_pass);
Status status = PassManager::Run(graph, passes);
EXPECT_EQ(SUCCESS, status);
}

Loading…
Cancel
Save