|
|
@@ -69,6 +69,10 @@ int64_t GetSymbolOutputOffset(const std::map<std::string, std::string> &anchor_t |
|
|
|
} |
|
|
|
return ge::kInvalidOffset; |
|
|
|
} |
|
|
|
|
|
|
|
bool isVariableMemoryNode(const ge::NodePtr &node) { |
|
|
|
return (node->GetType() == ge::VARIABLE) || (node->GetType() == ge::CONSTANTOP); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
namespace ge { |
|
|
|
Status VariableMemoryAssigner::Assign() { |
|
|
@@ -447,22 +451,31 @@ bool IsContinuousInputConflict(const ge::NodePtr &node, const OpDescPtr &peer_op |
|
|
|
/// op1 -> node -> op2 |
|
|
|
/// return true when node is ref from input, and op1 or op2 is reuse input from output |
|
|
|
bool GraphMemoryAssigner::IsRefFromInputOpCascade(const NodePtr &node) { |
|
|
|
bool ref_from_input = false; |
|
|
|
std::unordered_set<int32_t> ref_input_index; |
|
|
|
int32_t reuse_in_index = -1; |
|
|
|
for (const auto &out_anchor : node->GetAllOutDataAnchors()) { |
|
|
|
ref_from_input = GraphUtils::IsRefFromInput(out_anchor, reuse_in_index); |
|
|
|
if (ref_from_input) { |
|
|
|
bool reuse_input = GraphUtils::IsRefFromInput(out_anchor, reuse_in_index); |
|
|
|
if (reuse_input) { |
|
|
|
GELOGD("IsRefFromInputOpCascade: cur node:%s:%d is ref", node->GetName().c_str(), reuse_in_index); |
|
|
|
break; |
|
|
|
ref_input_index.insert(reuse_in_index); |
|
|
|
} |
|
|
|
} |
|
|
|
bool ref_from_input = !ref_input_index.empty(); |
|
|
|
if (!ref_from_input) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &in_anchor : node->GetAllInDataAnchors()) { |
|
|
|
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
|
|
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); |
|
|
|
auto in_node = peer_out_anchor->GetOwnerNode(); |
|
|
|
if (isVariableMemoryNode(in_node) && (ref_input_index.count(in_anchor->GetIdx()) > 0)) { |
|
|
|
GELOGD("Reuse variable memory, input node:%s, type:%s.", in_node->GetName().c_str(), in_node->GetType().c_str()); |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (ref_from_input && GraphUtils::IsRefFromInput(peer_out_anchor, reuse_in_index)) { |
|
|
|
GELOGD("IsRefFromInputOpCascade: in node[%s] is ref, reuse index is:%d", |
|
|
|
peer_out_anchor->GetOwnerNode()->GetName().c_str(), reuse_in_index); |
|
|
|
in_node->GetName().c_str(), reuse_in_index); |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
@@ -500,6 +513,11 @@ Status GraphMemoryAssigner::UpdateRefOpOffsetReverse(const NodePtr &node) { |
|
|
|
GE_CHECK_NOTNULL(peer_out_anchor); |
|
|
|
auto peer_node = peer_out_anchor->GetOwnerNode(); |
|
|
|
GE_CHECK_NOTNULL(peer_node); |
|
|
|
if (isVariableMemoryNode(peer_node)) { |
|
|
|
GELOGW("Peer node to update is %s, skip it. Node name:%s.", |
|
|
|
peer_node->GetType().c_str(), peer_node->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto peer_op_desc = peer_node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(peer_op_desc); |
|
|
|
vector<int64_t> peer_output_list = peer_op_desc->GetOutputOffset(); |
|
|
|