| @@ -88,6 +88,14 @@ Status VariableMemoryAssigner::AssignVarAttr2Nodes() { | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| Status VariableMemoryAssigner::AssignMemory2HasRefAttrNode() { | |||||
| Status result = ge::VarMemAssignUtil::AssignMemory2HasRefAttrNode(compute_graph_); | |||||
| if (result != ge::SUCCESS) { | |||||
| return result; | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| Status GraphMemoryAssigner::AssignMemory() { | Status GraphMemoryAssigner::AssignMemory() { | ||||
| ge::HybridMemAssignerPtr mem_assigner(new(std::nothrow) HybridMemAssigner(compute_graph_)); | ge::HybridMemAssignerPtr mem_assigner(new(std::nothrow) HybridMemAssigner(compute_graph_)); | ||||
| if (mem_assigner->Assign() != ge::SUCCESS) { | if (mem_assigner->Assign() != ge::SUCCESS) { | ||||
| @@ -135,6 +143,19 @@ ge::Status GraphMemoryAssigner::AssignVarAttr2Nodes() { | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| ge::Status GraphMemoryAssigner::AssignMemory2HasRefAttrNode() { | |||||
| auto variable_assigner = | |||||
| std::unique_ptr<ge::VariableMemoryAssigner>(new(std::nothrow) ge::VariableMemoryAssigner(compute_graph_)); | |||||
| if (variable_assigner == nullptr) { | |||||
| GELOGE(ge::FAILED, "Alloc VariableMemoryAssigner failed."); | |||||
| return ge::FAILED; | |||||
| } | |||||
| if (variable_assigner->AssignMemory2HasRefAttrNode() != ge::SUCCESS) { | |||||
| return ge::FAILED; | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| ge::Status CalculateTensorRealSizeAndOutSize(const ge::ConstGeTensorDescPtr &output_desc, | ge::Status CalculateTensorRealSizeAndOutSize(const ge::ConstGeTensorDescPtr &output_desc, | ||||
| int64_t dim_index, int64_t &output_mem_size, | int64_t dim_index, int64_t &output_mem_size, | ||||
| int64_t &batch_dim_num, int64_t &out_size) { | int64_t &batch_dim_num, int64_t &out_size) { | ||||
| @@ -63,6 +63,8 @@ class VariableMemoryAssigner { | |||||
| /// | /// | ||||
| ge::Status AssignVarAttr2Nodes(); | ge::Status AssignVarAttr2Nodes(); | ||||
| ge::Status AssignMemory2HasRefAttrNode(); | |||||
| private: | private: | ||||
| ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
| }; | }; | ||||
| @@ -99,6 +101,8 @@ class GraphMemoryAssigner { | |||||
| /// | /// | ||||
| ge::Status AssignVarAttr2Nodes(); | ge::Status AssignVarAttr2Nodes(); | ||||
| ge::Status AssignMemory2HasRefAttrNode(); | |||||
| ge::Status ReAssignMemory(bool is_loop_graph, map<int64_t, size_t> &mem_type_to_offset); | ge::Status ReAssignMemory(bool is_loop_graph, map<int64_t, size_t> &mem_type_to_offset); | ||||
| ge::Status AssignZeroCopyMemory(map<int64_t, size_t> &mem_offset, size_t &zero_mem_copy_size); | ge::Status AssignZeroCopyMemory(map<int64_t, size_t> &mem_offset, size_t &zero_mem_copy_size); | ||||
| @@ -40,6 +40,11 @@ Status MemoryAssigner::AssignMemory(bool is_loop_graph, map<int64_t, size_t> &me | |||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| if (graph_mem_assigner.AssignMemory2HasRefAttrNode() != ge::SUCCESS) { | |||||
| GELOGE(ge::FAILED, "Assign memory to node which has ref attr failed!"); | |||||
| return ge::FAILED; | |||||
| } | |||||
| // Assign memory for reference | // Assign memory for reference | ||||
| if (graph_mem_assigner.AssignReferenceMemory() != ge::SUCCESS) { | if (graph_mem_assigner.AssignReferenceMemory() != ge::SUCCESS) { | ||||
| GELOGE(ge::FAILED, "Assign reference memory failed!"); | GELOGE(ge::FAILED, "Assign reference memory failed!"); | ||||
| @@ -33,10 +33,7 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| Status VarMemAssignUtil::AssignVarMemory(ge::ComputeGraphPtr &compute_graph) { | Status VarMemAssignUtil::AssignVarMemory(ge::ComputeGraphPtr &compute_graph) { | ||||
| GE_CHK_STATUS_RET(AssignMemory2VariableNode(compute_graph)); | |||||
| GE_CHK_STATUS_RET(AssignMemory2HasRefAttrNode(compute_graph)); | |||||
| return SUCCESS; | |||||
| return AssignMemory2VariableNode(compute_graph); | |||||
| } | } | ||||
| Status VarMemAssignUtil::AssignConstantOpMemory(ge::ComputeGraphPtr &compute_graph) { | Status VarMemAssignUtil::AssignConstantOpMemory(ge::ComputeGraphPtr &compute_graph) { | ||||