|
|
@@ -34,11 +34,6 @@ using domi::SUCCESS; |
|
|
|
namespace ge { |
|
|
|
const int kValueIndexOutputIndex = 1; |
|
|
|
|
|
|
|
bool IsEmptyTensor(const GeShape &shape) { |
|
|
|
const auto &dims = shape.GetDims(); |
|
|
|
return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; }); |
|
|
|
} |
|
|
|
|
|
|
|
Status MergePass::Run(NodePtr &node) { |
|
|
|
GELOGD("MergePass running"); |
|
|
|
if (node == nullptr) { |
|
|
@@ -58,11 +53,6 @@ Status MergePass::Run(NodePtr &node) { |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
if (OptimizeEmptyTensorInput(node) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
auto in_data_nodes = node->GetInDataNodes(); |
|
|
|
switch (in_data_nodes.size()) { |
|
|
|
case 0: { |
|
|
@@ -212,30 +202,4 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const { |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) { |
|
|
|
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { |
|
|
|
const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); |
|
|
|
if (peer_data_anchor == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if ((peer_data_anchor->GetOwnerNode() == nullptr) || |
|
|
|
(peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc(); |
|
|
|
if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) { |
|
|
|
if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", |
|
|
|
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(), |
|
|
|
node->GetName().c_str(), in_data_anchor->GetIdx()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
GELOGD("Remove data edge %s:%d->%s:%d", |
|
|
|
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(), |
|
|
|
node->GetName().c_str(), in_data_anchor->GetIdx()); |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} // namespace ge |