Browse Source

modified: ge/graph/passes/merge_pass.cc

modified:   ge/graph/passes/merge_pass.h
tags/v1.1.0
zhaoxinxin 3 years ago
parent
commit
b8cf5089f1
2 changed files with 0 additions and 37 deletions
  1. +0
    -36
      ge/graph/passes/merge_pass.cc
  2. +0
    -1
      ge/graph/passes/merge_pass.h

+ 0
- 36
ge/graph/passes/merge_pass.cc View File

@@ -34,11 +34,6 @@ using domi::SUCCESS;
namespace ge {
const int kValueIndexOutputIndex = 1;

bool IsEmptyTensor(const GeShape &shape) {
const auto &dims = shape.GetDims();
return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; });
}

Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running");
if (node == nullptr) {
@@ -58,11 +53,6 @@ Status MergePass::Run(NodePtr &node) {
return PARAM_INVALID;
}

if (OptimizeEmptyTensorInput(node) != SUCCESS) {
GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str());
return FAILED;
}

auto in_data_nodes = node->GetInDataNodes();
switch (in_data_nodes.size()) {
case 0: {
@@ -212,30 +202,4 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const {
}
return true;
}

Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_data_anchor == nullptr) {
continue;
}
if ((peer_data_anchor->GetOwnerNode() == nullptr) ||
(peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) {
continue;
}
const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc();
if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) {
if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.",
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
node->GetName().c_str(), in_data_anchor->GetIdx());
return FAILED;
}
GELOGD("Remove data edge %s:%d->%s:%d",
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
node->GetName().c_str(), in_data_anchor->GetIdx());
}
}
return SUCCESS;
}
} // namespace ge

+ 0
- 1
ge/graph/passes/merge_pass.h View File

@@ -29,7 +29,6 @@ class MergePass : public BaseNodePass {
Status ChangeIndexToConstant(NodePtr &node, int &value_index);
Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc);
bool IsMergeInputNeedOptimized(NodePtr &node) const;
static Status OptimizeEmptyTensorInput(const NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_MERGE_PASS_H_

Loading…
Cancel
Save