From 2561188d96a5111f8b8df3e9b7aec9ed85e5130e Mon Sep 17 00:00:00 2001 From: wangzhengjun Date: Tue, 30 Mar 2021 20:45:52 +0800 Subject: [PATCH] buffer pool memory allocator --- ge/CMakeLists.txt | 4 + ge/ge_inference.mk | 1 + ge/ge_runner.mk | 1 + ge/graph/build/memory/block_mem_assigner.cc | 11 +- ge/graph/build/memory/block_mem_assigner.h | 2 + .../build/memory/buffer_pool_mem_assigner.cc | 234 +++++ .../build/memory/buffer_pool_mem_assigner.h | 83 ++ ge/graph/build/memory/graph_mem_assigner.cc | 52 + ge/graph/build/memory/graph_mem_assigner.h | 2 + ge/graph/build/memory/module.mk | 1 + ge/graph/build/run_context.cc | 5 +- ge/graph/build/stream_allocator.cc | 224 +++- ge/graph/build/stream_allocator.h | 4 + ge/graph/common/omg_util.cc | 40 + ge/graph/common/omg_util.h | 21 + ge/graph/load/model_manager/davinci_model.cc | 10 +- ge/graph/manager/graph_manager.cc | 7 + ge/graph/passes/buffer_pool_memory_pass.cc | 574 ++++++++++ ge/graph/passes/buffer_pool_memory_pass.h | 136 +++ tests/depends/runtime/src/runtime_stub.cc | 5 + tests/ut/ge/CMakeLists.txt | 5 + .../buffer_pool_mem_assigner_unittest.cc | 607 +++++++++++ .../buffer_pool_memory_pass_unittest.cc | 591 +++++++++++ .../graph/utils/buffer_pool_graph_builder.cc | 978 ++++++++++++++++++ .../graph/utils/buffer_pool_graph_builder.h | 279 +++++ 25 files changed, 3868 insertions(+), 9 deletions(-) create mode 100644 ge/graph/build/memory/buffer_pool_mem_assigner.cc create mode 100644 ge/graph/build/memory/buffer_pool_mem_assigner.h create mode 100644 ge/graph/passes/buffer_pool_memory_pass.cc create mode 100644 ge/graph/passes/buffer_pool_memory_pass.h create mode 100644 tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc create mode 100644 tests/ut/ge/graph/passes/buffer_pool_memory_pass_unittest.cc create mode 100644 tests/ut/ge/graph/utils/buffer_pool_graph_builder.cc create mode 100644 tests/ut/ge/graph/utils/buffer_pool_graph_builder.h diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index c92cbdca..87e89a38 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -329,6 +329,7 @@ set(TRAIN_SRC_LIST "graph/passes/memcpy_addr_async_pass.cc" "graph/passes/parallel_group_pass.cc" "graph/passes/set_input_output_offset_pass.cc" + "graph/passes/buffer_pool_memory_pass.cc" "graph/preprocess/graph_preprocess.cc" "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" @@ -407,6 +408,7 @@ set(TRAIN_SRC_LIST "graph/build/memory/hybrid_mem_assigner.cc" "graph/build/memory/max_block_mem_assigner.cc" "graph/build/memory/var_mem_assign_util.cc" + "graph/build/memory/buffer_pool_mem_assigner.cc" ) set(INFER_SRC_LIST @@ -617,6 +619,7 @@ set(INFER_SRC_LIST "graph/passes/memcpy_addr_async_pass.cc" "graph/passes/set_input_output_offset_pass.cc" "graph/passes/parallel_group_pass.cc" + "graph/passes/buffer_pool_memory_pass.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/rt_context_util.cc" "graph/manager/util/variable_accelerate_ctrl.cc" @@ -680,6 +683,7 @@ set(INFER_SRC_LIST "graph/build/memory/hybrid_mem_assigner.cc" "graph/build/memory/max_block_mem_assigner.cc" "graph/build/memory/var_mem_assign_util.cc" + "graph/build/memory/buffer_pool_mem_assigner.cc" ) if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index 5d5e734c..f30ba22a 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -222,6 +222,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/hccl_group_pass.cc \ graph/passes/memcpy_addr_async_pass.cc \ graph/passes/set_input_output_offset_pass.cc \ + graph/passes/buffer_pool_memory_pass.cc \ OMG_DEVICE_SRC_FILES := $(OMG_HOST_SRC_FILES) diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 421d41e8..0efcf820 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -246,6 +246,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/end_of_sequence_add_control_pass.cc \ graph/passes/memcpy_addr_async_pass.cc \ graph/passes/set_input_output_offset_pass.cc \ + graph/passes/buffer_pool_memory_pass.cc \ graph/preprocess/graph_preprocess.cc \ graph/preprocess/insert_op/ge_aipp_op.cc \ graph/preprocess/insert_op/util_insert_aipp_op.cc \ diff --git a/ge/graph/build/memory/block_mem_assigner.cc b/ge/graph/build/memory/block_mem_assigner.cc index 6fbb9826..ad5ed1a2 100755 --- a/ge/graph/build/memory/block_mem_assigner.cc +++ b/ge/graph/build/memory/block_mem_assigner.cc @@ -1655,6 +1655,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector bool is_atomic = false; // If GetBool fail, is_atomic is false. (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic); + bool is_buffer_pool_mem_supported = (op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID)) && + (op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) && (!root_unknown_shape_flag_); // Allocate memory for the current node and release node memory of the same size in the workspace GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1", for (auto iter = stream_workspace_blocks_.begin(); iter != stream_workspace_blocks_.end(); @@ -1694,7 +1696,7 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector GE_IF_BOOL_EXEC(!no_need_assign_memory, no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);); } - no_need_assign_memory = (no_need_assign_memory || IsKnownSubgraphData(node)); + no_need_assign_memory = (no_need_assign_memory || IsKnownSubgraphData(node) || is_buffer_pool_mem_supported); if (no_need_assign_memory) { zero_memory_list_.emplace_back(node, kOutput, i, false); continue; @@ -1740,6 +1742,13 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { const char *op_no_reuse_mem = std::getenv(OP_NO_REUSE_MEM); GE_IF_BOOL_EXEC(op_no_reuse_mem != nullptr, op_no_reuse_mem_str = string(op_no_reuse_mem); CheckAndGetOpReuseEnv(op_no_reuse_mem_str, op_no_reuse_mem_vec_, op_reuse_env_valid_);); + auto root_graph = GraphUtils::FindRootGraph(compute_graph_); + if (root_graph == nullptr) { + GELOGE(INTERNAL_ERROR, "[Check][RootGraph]Root graph is nullptr, graph:%s.", compute_graph_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Root graph is nullptr, graph:%s.", compute_graph_->GetName().c_str()); + return; + } + root_unknown_shape_flag_ = root_graph->GetGraphUnknownFlag(); for (NodePtr &n : compute_graph_->GetAllNodes()) { auto node_op_desc = n->GetOpDesc(); diff --git a/ge/graph/build/memory/block_mem_assigner.h b/ge/graph/build/memory/block_mem_assigner.h index 199a84f9..474db17c 100755 --- a/ge/graph/build/memory/block_mem_assigner.h +++ b/ge/graph/build/memory/block_mem_assigner.h @@ -494,6 +494,8 @@ class BlockMemAssigner : public MemAssigner { /// @ [stream2][nodeid] /// DependStreamLife total_node_depend_stream_life_; + + bool root_unknown_shape_flag_ = false; }; } // namespace ge #endif // GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_ diff --git a/ge/graph/build/memory/buffer_pool_mem_assigner.cc b/ge/graph/build/memory/buffer_pool_mem_assigner.cc new file mode 100644 index 00000000..d66fe038 --- /dev/null +++ b/ge/graph/build/memory/buffer_pool_mem_assigner.cc @@ -0,0 +1,234 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/build/memory/buffer_pool_mem_assigner.h" +#include "graph/common/omg_util.h" +#include "graph/utils/tensor_utils.h" +#include "framework/common/util.h" +#include "graph/compute_graph.h" +#include "graph/debug/ge_attr_define.h" +#include "common/math/math_util.h" +#include "common/util/error_manager/error_manager.h" + +namespace ge { +namespace { +const size_t kBufferPoolNodeMemInfoLength = 2; +const uint32_t kBufferPoolNodeOutputSizeIndex = 0; +const uint32_t kBufferPoolNodeOutputOffsetIndex = 1; +} // namespace + +Status BufferPoolMemAssigner::Assign() { + if (compute_graph_ == nullptr) { + GELOGE(PARAM_INVALID, "[Check][Graph]Graph is nullptr"); + REPORT_INNER_ERROR("E19999", "Input graph is nullptr"); + return PARAM_INVALID; + } + Status ret = InitAssigner(compute_graph_); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Init][Assigner]Graph:%s.", compute_graph_->GetName().c_str()); + return FAILED; + } + ret = AssignOutput(); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Assign][Output]Graph:%s.", compute_graph_->GetName().c_str()); + return FAILED; + } + return SUCCESS; +} + +Status BufferPoolMemAssigner::GetOutputMemoryType(const NodePtr &node, size_t idx, int64_t &memory_type) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + memory_type = RT_MEMORY_HBM; + std::vector type_list; + bool has_mem_type = ge::AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_OUTPUT_MEM_TYPE_LIST, type_list); + if (has_mem_type && (type_list.size() != node->GetOpDesc()->GetOutputsSize() || idx >= type_list.size())) { + GELOGE(PARAM_INVALID, "[Check][OutputParam]Output param invalid, output size:%zu, mem type size:%zu, index:%zu.", + node->GetOpDesc()->GetOutputsSize(), type_list.size(), idx); + REPORT_INNER_ERROR("E19999", "Output param invalid, output size:%zu, mem type size:%zu, index:%zu.", + node->GetOpDesc()->GetOutputsSize(), type_list.size(), idx); + return PARAM_INVALID; + } + memory_type = has_mem_type ? type_list[idx] : RT_MEMORY_HBM; + return SUCCESS; +} + +Status BufferPoolMemAssigner::InitAssigner(const ComputeGraphPtr &graph) { + for (const NodePtr &node : graph->GetAllNodes()) { + int64_t buffer_pool_id = 0; + int64_t buffer_pool_size = 0; + bool get_attr = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, buffer_pool_id); + get_attr = get_attr && (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, buffer_pool_size)); + if (get_attr) { + std::string batch_label; + (void) AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); + buffer_pool_nodes_[batch_label][buffer_pool_id].emplace_back(node); + auto iter = buffer_pool_size_[batch_label].find(buffer_pool_id); + if (iter == buffer_pool_size_[batch_label].end()) { + buffer_pool_size_[batch_label][buffer_pool_id] = buffer_pool_size; + } + Status ret = InitMemOffsetBase(node); + if (ret != SUCCESS) { + GELOGE(ret, "[Init][MemOffsetBase]Batch label:%s.", batch_label.c_str()); + REPORT_INNER_ERROR("E19999", "Failed to init offset base, batch label:%s.", batch_label.c_str()); + return ret; + } + } + } + + int64_t max_size = 0; + for (const auto &iter : buffer_pool_size_) { + std::string batch_label = iter.first; + int64_t batch_offset = mem_offset_base_; + for (const auto &buffer_pool : iter.second) { + int64_t buffer_pool_id = buffer_pool.first; + int64_t buffer_pool_size = buffer_pool.second; + buffer_pool_offset_base_[batch_label][buffer_pool_id] = batch_offset; + FMK_INT64_ADDCHECK(buffer_pool_size, kBufferPoolMemAlignSize); + AlignMemSize(buffer_pool_size, kBufferPoolMemAlignSize); + FMK_INT64_ADDCHECK(batch_offset, (buffer_pool_size + kBufferPoolMemAlignSize)); + batch_offset += (buffer_pool_size + kBufferPoolMemAlignSize); + } + int64_t batch_mem_size = batch_offset - mem_offset_base_; + GELOGI("[Init][Assigner]Get batch mem size, batch label:%s, mem size:%ld.", batch_label.c_str(), batch_mem_size); + if (max_size < batch_mem_size) { + max_size = batch_mem_size; + } + } + FMK_INT64_ADDCHECK(mem_offset_base_, max_size); + mem_offset_ = static_cast(mem_offset_base_ + max_size); + GELOGI("[Init][Assigner]Init buffer pool mem assigner successfully, " + "mem type:%ld, mem offset base:%ld, mem offset:%zu.", mem_type_, mem_offset_base_, mem_offset_); + return SUCCESS; +} + +Status BufferPoolMemAssigner::InitMemOffsetBase(const NodePtr &node) { + int64_t mem_type; + Status ret = GetOutputMemoryType(node, static_cast(kBufferPoolNodeOutIndex), mem_type); + if (ret != SUCCESS) { + GELOGE(ret, "[Get][MemType]Node:%s, index:%u.", node->GetName().c_str(), kBufferPoolNodeOutIndex); + REPORT_INNER_ERROR("E19999", "Failed to get output memory type, node:%s, index:%u.", + node->GetName().c_str(), kBufferPoolNodeOutIndex); + return ret; + } + if (mem_type_ != mem_type && init_offset_base_) { + GELOGE(PARAM_INVALID, "[Check][MemType]The memory type of all buffer pool nodes must be the same, node:%s, " + "required:%ld, actually: %ld", node->GetName().c_str(), mem_type_, mem_type); + REPORT_INNER_ERROR("E19999", "The memory type of all buffer pool nodes must be the same, node:%s, " + "required:%ld, actually: %ld", node->GetName().c_str(), mem_type_, mem_type); + return PARAM_INVALID; + } + if (!init_offset_base_) { + auto iter = mem_type_to_offset_.find(mem_type); + if (iter == mem_type_to_offset_.end()) { + GELOGE(PARAM_INVALID, "[Check][MemType]Memory type is not supported, node:%s, mem type:%ld.", + node->GetName().c_str(), mem_type); + REPORT_INNER_ERROR("E19999", "Memory type is not supported, node:%s, mem type:%ld.", + node->GetName().c_str(), mem_type); + return PARAM_INVALID; + } + mem_offset_base_ = static_cast(iter->second); + FMK_INT64_ADDCHECK(mem_offset_base_, (kBufferPoolMemAlignSize + kBufferPoolMemAlignSize)); + AlignMemSize(mem_offset_base_, kBufferPoolMemAlignSize); + // The HCOM nodes may access the previous 512 bytes. + mem_offset_base_ += kBufferPoolMemAlignSize; + mem_type_ = mem_type; + init_offset_base_ = true; + GELOGI("[Init][MemOffsetBase]Init offset base:%ld, memory type:%ld", mem_offset_base_, mem_type); + } + return SUCCESS; +} + +Status BufferPoolMemAssigner::AssignOutput() { + for (auto &batch_pool_nodes_map : buffer_pool_nodes_) { + std::string batch_label = batch_pool_nodes_map.first; + for (auto &pool_nodes_map : batch_pool_nodes_map.second) { + int64_t buffer_pool_id = pool_nodes_map.first; + auto iter_buffer_id_size = buffer_pool_size_[batch_label].find(buffer_pool_id); + if (iter_buffer_id_size == buffer_pool_size_[batch_label].end()) { + GELOGE(INTERNAL_ERROR, "[Get][BufferPoolSize]Pool id:%ld.", buffer_pool_id); + REPORT_INNER_ERROR("E19999", "Failed to get buffer pool size, pool id:%ld.", buffer_pool_id); + return INTERNAL_ERROR; + } + auto iter_buffer_id_offset = buffer_pool_offset_base_[batch_label].find(buffer_pool_id); + if (iter_buffer_id_offset == buffer_pool_offset_base_[batch_label].end()) { + GELOGE(INTERNAL_ERROR, "[Get][BufferPoolBaseOffset]Pool id:%ld.", buffer_pool_id); + REPORT_INNER_ERROR("E19999", "Failed to get buffer pool base offset, pool id:%ld.", buffer_pool_id); + return INTERNAL_ERROR; + } + int64_t buffer_pool_size = iter_buffer_id_size->second; + int64_t output_offset_base = iter_buffer_id_offset->second; + Status ret = AssignOutputInOneBufferPool(batch_label, output_offset_base, pool_nodes_map.second); + if (ret != SUCCESS) { + GELOGE(ret, "[Assign][OneBufferPool]Batch label:%s, pool id:%ld, pool size:%ld, offset base:%ld.", + batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base); + REPORT_INNER_ERROR("E19999", "Failed to assign output memory, batch label:%s, " + "pool id:%ld, pool size:%ld, offset base:%ld.", + batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base); + return ret; + } + GELOGI("[Assign][Output]Assign output successfully, batch label:%s, pool id:%ld, pool size:%ld, offset base:%ld.", + batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base); + } + } + return SUCCESS; +} + +Status BufferPoolMemAssigner::AssignOutputInOneBufferPool(const std::string &batch_label, + int64_t output_offset_base, + const std::vector &buffer_pool_nodes) { + for (const NodePtr &node : buffer_pool_nodes) { + int64_t output_size = 0; + Status ret = GetMemorySize(node, output_size); + if (ret != SUCCESS) { + GELOGE(ret, "[Get][MemSize]Node:%s.", node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to get output size, node:%s.", node->GetName().c_str()); + return ret; + } + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + vector memory_size_and_offset; + bool get_attr = AttrUtils::GetListInt(op_desc, ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET, memory_size_and_offset); + if (!get_attr || memory_size_and_offset.size() != kBufferPoolNodeMemInfoLength) { + GELOGE(PARAM_INVALID, "[Get][Attr]Node:%s, mem info size:%zu, required size:%zu.", + node->GetName().c_str(), memory_size_and_offset.size(), kBufferPoolNodeMemInfoLength); + REPORT_INNER_ERROR("E19999", "Failed to get pool node memory info, node:%s, info size:%zu, required size:%zu.", + node->GetName().c_str(), memory_size_and_offset.size(), kBufferPoolNodeMemInfoLength); + return PARAM_INVALID; + } + if (output_size != memory_size_and_offset[kBufferPoolNodeOutputSizeIndex]) { + GELOGE(PARAM_INVALID, "[Check][MemSize]Something wrong with memory size, pre size:%ld, curr size:%ld, node:%s.", + memory_size_and_offset[kBufferPoolNodeOutputSizeIndex], output_size, node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Something wrong with memory size, pre size:%ld, curr size:%ld, node:%s.", + memory_size_and_offset[kBufferPoolNodeOutputSizeIndex], output_size, node->GetName().c_str()); + return PARAM_INVALID; + } + + int64_t logical_offset = memory_size_and_offset[kBufferPoolNodeOutputOffsetIndex]; + vector output_list = {(output_offset_base + logical_offset)}; + op_desc->SetOutputOffset(output_list); + // log for IMAS tools + GELOGI("[IMAS]Set %s name[%s] optype[%s] %s[%u] offset to [%ld] streamid[%ld] memtype[%ld] " + "size[%zu] realsize[%zu] noalignsize[%zu] life time begin[%d] life time end[%d] " + "child[%d:%d:%d:%d:%d] isref[%d] batch[%s]", + compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), + "output", kBufferPoolNodeOutIndex, output_list[kBufferPoolNodeOutIndex], op_desc->GetStreamId(), mem_type_, + static_cast(output_size), static_cast(output_size), static_cast(output_size), + 0, 0, 0, 0, 0, 0, 0, 0, batch_label.c_str()); + } + return SUCCESS; +} + +} // namespace ge diff --git a/ge/graph/build/memory/buffer_pool_mem_assigner.h b/ge/graph/build/memory/buffer_pool_mem_assigner.h new file mode 100644 index 00000000..6caed031 --- /dev/null +++ b/ge/graph/build/memory/buffer_pool_mem_assigner.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_ +#define GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_ + +#include +#include +#include +#include "graph/build/memory/mem_assigner.h" +#include "runtime/mem.h" + +namespace ge { +class BufferPoolMemAssigner : public MemAssigner { + public: + BufferPoolMemAssigner(ComputeGraphPtr compute_graph, const std::map &mem_type_to_offset) + : MemAssigner(), compute_graph_(compute_graph), + mem_type_(0), + mem_offset_(0), + mem_offset_base_(0), + init_offset_base_(false), + mem_type_to_offset_(mem_type_to_offset) {} + + BufferPoolMemAssigner(const BufferPoolMemAssigner &) = delete; + + BufferPoolMemAssigner &operator=(const BufferPoolMemAssigner &) = delete; + + ~BufferPoolMemAssigner() override = default; + + Status Assign() override; + + size_t GetMemOffset() const { return mem_offset_; } + + int64_t GetMemType() const { return mem_type_; } + + private: + static Status GetOutputMemoryType(const NodePtr &node, size_t idx, int64_t &memory_type); + + Status InitAssigner(const ComputeGraphPtr &graph); + + Status InitMemOffsetBase(const NodePtr &node); + + Status AssignOutput(); + + Status AssignOutputInOneBufferPool(const std::string &batch_label, + int64_t output_offset_base, + const std::vector &buffer_pool_nodes); + + ComputeGraphPtr compute_graph_; + + int64_t mem_type_; + + size_t mem_offset_; + + int64_t mem_offset_base_; + + bool init_offset_base_; + + std::map mem_type_to_offset_; + + // Use map to ensure that each visit is in the order of pool id + std::unordered_map>> buffer_pool_nodes_; + + // Use map to ensure that each visit is in the order of pool id + std::unordered_map> buffer_pool_size_; + + std::unordered_map> buffer_pool_offset_base_; +}; +} // namespace ge +#endif // GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_ diff --git a/ge/graph/build/memory/graph_mem_assigner.cc b/ge/graph/build/memory/graph_mem_assigner.cc index 44ba780d..9b53403a 100755 --- a/ge/graph/build/memory/graph_mem_assigner.cc +++ b/ge/graph/build/memory/graph_mem_assigner.cc @@ -30,6 +30,7 @@ #include "graph/manager/graph_var_manager.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "graph/build/memory/buffer_pool_mem_assigner.h" namespace { const int kAllInputAddrIsAtomic = -1; @@ -231,6 +232,7 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, map bool { + for (NodePtr &node : graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + bool has_attrs = op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE); + if (has_attrs) { + return true; + } + } + return false; + }; + auto root_graph = GraphUtils::FindRootGraph(compute_graph_); + GE_CHECK_NOTNULL(root_graph); + if (root_graph->GetGraphUnknownFlag()) { + GELOGI("[Check][Enable]Unknown root graph does not support buffer pool memory, graph:%s.", + compute_graph_->GetName().c_str()); + return SUCCESS; + } + if (!is_buffer_pool_mem_enable(compute_graph_)) { + GELOGD("[Check][Enable]Buffer pool memory is not enable, graph:%s.", compute_graph_->GetName().c_str()); + return SUCCESS; + } + map mem_type_to_offset; + for (const auto &pair : memory_offset_) { + mem_type_to_offset[pair.first] = pair.second.mem_offset_; + } + BufferPoolMemAssigner buffer_pool_mem_assigner(compute_graph_, mem_type_to_offset); + Status status = buffer_pool_mem_assigner.Assign(); + if (status != SUCCESS) { + GELOGE(status, "[Assign][BufferPoolMem]Graph:%s.", compute_graph_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to assign buffer pool memory, graph:%s.", compute_graph_->GetName().c_str()); + return status; + } + int64_t mem_type = buffer_pool_mem_assigner.GetMemType(); + auto iter = memory_offset_.find(mem_type); + if (iter == memory_offset_.end()) { + GELOGE(FAILED, "[Check][MemType]Memory type is not supported, graph:%s, mem type:%ld.", + compute_graph_->GetName().c_str(), mem_type); + REPORT_INNER_ERROR("E19999", "Memory type is not supported, graph:%s, mem type:%ld.", + compute_graph_->GetName().c_str(), mem_type); + return FAILED; + } + iter->second.mem_offset_ = buffer_pool_mem_assigner.GetMemOffset(); + GELOGI("[Assign][BufferPoolMem]Assign buffer pool memory successfully, graph:%s, mem type:%ld, mem offset:%zu.", + compute_graph_->GetName().c_str(), mem_type, buffer_pool_mem_assigner.GetMemOffset()); + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/build/memory/graph_mem_assigner.h b/ge/graph/build/memory/graph_mem_assigner.h index 756781fe..92e599b8 100755 --- a/ge/graph/build/memory/graph_mem_assigner.h +++ b/ge/graph/build/memory/graph_mem_assigner.h @@ -188,6 +188,8 @@ class GraphMemoryAssigner { void PrintMemoryOffset(); + Status AssignBufferPoolMemory(); + MemoryOffsetMap memory_offset_; ge::ComputeGraphPtr compute_graph_; HybridMemAssignerPtr mem_assigner_; diff --git a/ge/graph/build/memory/module.mk b/ge/graph/build/memory/module.mk index 73617794..232c2fed 100755 --- a/ge/graph/build/memory/module.mk +++ b/ge/graph/build/memory/module.mk @@ -8,6 +8,7 @@ local_lib_src_files := memory_assigner.cc \ hybrid_mem_assigner.cc \ max_block_mem_assigner.cc \ var_mem_assign_util.cc \ + buffer_pool_mem_assigner.cc \ local_lib_inc_path := ${LOCAL_PATH} \ ${TOPDIR}inc \ diff --git a/ge/graph/build/run_context.cc b/ge/graph/build/run_context.cc index 100d5aee..c5fdfec1 100644 --- a/ge/graph/build/run_context.cc +++ b/ge/graph/build/run_context.cc @@ -18,6 +18,7 @@ #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" +#include "graph/common/omg_util.h" namespace ge { RunContextUtil::~RunContextUtil() { DestroyRtModelResources(); } @@ -88,9 +89,11 @@ Status RunContextUtil::CreateRtModelResources(uint32_t stream_num, uint32_t even } // Create rt event + uint32_t create_flag = static_cast((event_num > kEventReuseThreshold) ? RT_EVENT_WITH_FLAG : + RT_EVENT_DEFAULT); for (uint32_t i = 0; i < event_num; ++i) { rtEvent_t event = nullptr; - rt_ret = rtEventCreate(&event); + rt_ret = rtEventCreateWithFlag(&event, create_flag); if (rt_ret != RT_ERROR_NONE) { REPORT_CALL_ERROR("E19999", "call rtEventCreate fail, ret:%d, index:%u, when %s", static_cast(rt_ret), i, __FUNCTION__); diff --git a/ge/graph/build/stream_allocator.cc b/ge/graph/build/stream_allocator.cc index b1df0f2c..e1d1f937 100644 --- a/ge/graph/build/stream_allocator.cc +++ b/ge/graph/build/stream_allocator.cc @@ -27,6 +27,8 @@ #include "graph/ge_context.h" #include "graph/utils/graph_utils.h" #include "init/gelib.h" +#include "common/string_util.h" +#include "common/util/error_manager/error_manager.h" using std::map; using std::set; @@ -38,6 +40,13 @@ const int64_t kTaskNumPerNormalNode = 3; const int64_t kTaskNumPerHcclNode = 245; const char *const kTrueStr = "true"; const char *const kFalseStr = "false"; +const size_t kEventMultiplexingItemCount = 3; +const size_t kKeyWordIndex = 0; +const size_t kNodeNameIndex = 1; +const size_t kEventIdIndex = 2; +const char *const kSend = "SendTo"; +const char *const kRecv = "RecvFrom"; +const char kDelim = ';'; inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string &continuous_stream_label) { if (ge::AttrUtils::GetStr(op_desc, ge::ATTR_NAME_CONTINUOUS_STREAM_LABEL, continuous_stream_label)) { @@ -52,6 +61,97 @@ bool IsHcclOp(const string &op_type) { ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); return hccl_op_types.find(op_type) != hccl_op_types.end(); } + +ge::Status ParseNodeEventMultiplexing(const ge::NodePtr &node, + const std::vector &raw_event_multiplexing, + std::unordered_map>> &node_to_send, + std::unordered_map>> &node_to_recv) { + GE_CHECK_NOTNULL(node); + for (const auto &str : raw_event_multiplexing) { + std::vector ele = ge::StringUtils::Split(str, kDelim); + if (ele.size() != kEventMultiplexingItemCount) { + GELOGE(ge::PARAM_INVALID, "[Check][RawMultiplexing]Size error, node:%s, require size:%zu, actually:%zu.", + node->GetName().c_str(), kEventMultiplexingItemCount, ele.size()); + REPORT_INNER_ERROR("E19999", "Raw event multiplexing is invalid, node:%s, require size:%zu, actually:%zu.", + node->GetName().c_str(), kEventMultiplexingItemCount, ele.size()); + return ge::PARAM_INVALID; + } + int value; + try { + value = std::stoi(ele[kEventIdIndex]); + } catch (std::invalid_argument &) { + GELOGE(ge::PARAM_INVALID, "[Throw][Exception]Event id is invalid, node:%s, raw:%s.", + node->GetName().c_str(), ele[kEventIdIndex].c_str()); + REPORT_INNER_ERROR("E19999", "Event id is invalid, node:%s, raw:%s.", + node->GetName().c_str(), ele[kEventIdIndex].c_str()); + return ge::PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(ge::PARAM_INVALID, "[Throw][Exception]Event id is out of range, node:%s, raw:%s.", + node->GetName().c_str(), ele[kEventIdIndex].c_str()); + REPORT_INNER_ERROR("E19999", "Event id is out of range, node:%s, raw:%s.", + node->GetName().c_str(), ele[kEventIdIndex].c_str()); + return ge::PARAM_INVALID; + } + if (value < 0) { + GELOGE(ge::PARAM_INVALID, "[Check][EventId]Event id is out of range, node:%s, raw:%s, value:%d.", + node->GetName().c_str(), ele[kEventIdIndex].c_str(), value); + REPORT_INNER_ERROR("E19999", "Event id is out of range, node:%s, raw:%s, value:%d.", + node->GetName().c_str(), ele[kEventIdIndex].c_str(), value); + return ge::PARAM_INVALID; + } + if (ele[kKeyWordIndex] == kSend) { + node_to_send[node].emplace_back(std::make_pair(ele[kNodeNameIndex], static_cast(value))); + } else if (ele[kKeyWordIndex] == kRecv) { + node_to_recv[node].emplace_back(std::make_pair(ele[kNodeNameIndex], static_cast(value))); + } else { + GELOGE(ge::PARAM_INVALID, "[Check][KeyWord]Key word is not supported, node:%s, key:%s.", + node->GetName().c_str(), ele[kEventIdIndex].c_str()); + REPORT_INNER_ERROR("E19999", "Key word is not supported, node:%s, key:%s.", + node->GetName().c_str(), ele[kEventIdIndex].c_str()); + return ge::PARAM_INVALID; + } + } + return ge::SUCCESS; +} + +ge::Status ParseAllNodeEventMultiplexing(const ge::ComputeGraphPtr &graph, + std::unordered_map &name_to_node_map, + std::unordered_map>> &node_to_send, + std::unordered_map>> &node_to_recv) { + for (const auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { + ge::OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + name_to_node_map.insert({node->GetName(), node}); + std::vector raw_event_multiplexing; + if (!(op_desc->HasAttr(ge::ATTR_NAME_EVENT_MULTIPLEXING))) { + continue; + } + bool get_attr = ge::AttrUtils::GetListStr(op_desc, ge::ATTR_NAME_EVENT_MULTIPLEXING, raw_event_multiplexing); + if (!get_attr) { + GELOGE(ge::PARAM_INVALID, "[Get][Attr]Node:%s.", node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to get raw event multiplexing, node:%s.", node->GetName().c_str()); + return ge::PARAM_INVALID; + } + auto parse_ret = ParseNodeEventMultiplexing(node, raw_event_multiplexing, node_to_send, node_to_recv); + if (parse_ret != ge::SUCCESS) { + GELOGE(parse_ret, "[Parse][Eventmultiplexing]Node:%s.", node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to parse node event multiplexing, node:%s.", node->GetName().c_str()); + return parse_ret; + } + } + return ge::SUCCESS; +} + +std::vector GetIntersection(std::vector &a, std::vector &b) { + std::unordered_set ele_of_a(a.begin(), a.end()); + std::vector res; + for (auto &ele : b) { + if (ele_of_a.count(ele) > 0) { + res.emplace_back(ele); + } + } + return res; +} } // namespace namespace ge { @@ -150,6 +250,12 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu return status; } + status = RefreshEventsWithReuse(); + if (status != SUCCESS) { + GELOGE(status, "[Refresh][Events]RefreshEventsWithReuse failed!"); + return status; + } + status = InsertSyncEventNodes(); if (status != SUCCESS) { GELOGE(status, "InsertSyncEventNode failed!"); @@ -1161,6 +1267,94 @@ Status StreamAllocator::CheckStreamActived() const { return SUCCESS; } +Status StreamAllocator::ReuseEvent(bool send_to, + const std::unordered_map &name_to_node_map, + const std::unordered_map>> &node_to_event_id) { + for (const auto &node_event_id : node_to_event_id) { + ge::NodePtr curr_node = node_event_id.first; + NodePtr send_node = send_to ? curr_node : nullptr; + NodePtr recv_node = send_to ? nullptr : curr_node; + for (const auto &event_pair : node_event_id.second) { + auto peer_node_iter = name_to_node_map.find(event_pair.first); + if (peer_node_iter == name_to_node_map.end()) { + GELOGE(PARAM_INVALID, "[Get][Node]Name:%s.", event_pair.first.c_str()); + REPORT_INNER_ERROR("E19999", "Failed to find node, name:%s.", event_pair.first.c_str()); + return PARAM_INVALID; + } + recv_node = send_to ? peer_node_iter->second : recv_node; + send_node = send_to ? send_node : peer_node_iter->second; + GE_CHECK_NOTNULL(send_node); + GE_CHECK_NOTNULL(recv_node); + auto event_id = GetIntersection(node_to_send_events_[send_node], node_to_recv_events_[recv_node]); + uint32_t new_event = event_pair.second + event_num_; + if (event_id.empty()) { + GELOGI("[Check][Optimized]Send:%s, recv:%s.", send_node->GetName().c_str(), recv_node->GetName().c_str()); + continue; + } else if (event_id.size() != 1) { + GELOGW("[Check][Event]More than one event are found between %s and %s, event num:%zu.", + send_node->GetName().c_str(), recv_node->GetName().c_str(), event_id.size()); + } + uint32_t old_event = event_id[0]; + auto reuse_event_id = [] (vector &event_list, uint32_t old_event, uint32_t new_event) -> void { + event_list.erase(std::remove(event_list.begin(), event_list.end(), old_event), event_list.end()); + event_list.push_back(new_event); + return; + }; + reuse_event_id(node_to_send_events_[send_node], old_event, new_event); + reuse_event_id(node_to_recv_events_[recv_node], old_event, new_event); + GELOGI("[Reuse][Event]Replace event successfully, send node:%s, recv node:%s, old id:%u, new id:%u.", + send_node->GetName().c_str(), recv_node->GetName().c_str(), old_event, new_event); + } + } + return ge::SUCCESS; +} + +// Refresh events to reuse events +Status StreamAllocator::RefreshEventsWithReuse() { + GELOGI("[Refresh][Events]Refresh events with reuse, stream num:%ld, original event num:%u.", stream_num_, event_num_); + if (event_num_ <= kEventReuseThreshold) { + GELOGI("[Check][ReuseThreshold]Event used num is %u, less than %u, skip reuse.", + event_num_, kEventReuseThreshold); + return SUCCESS; + } + std::unordered_map name_to_node_map; + std::unordered_map>> node_to_send; + std::unordered_map>> node_to_recv; + Status ret = ParseAllNodeEventMultiplexing(whole_graph_, name_to_node_map, node_to_send, node_to_recv); + if (ret != SUCCESS) { + GELOGE(ret, "[Parse][AllNodeEventMultiplexing]Graph:%s.", whole_graph_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to parse all node event multiplexing, graph:%s.", + whole_graph_->GetName().c_str()); + return ret; + } + if (node_to_send.empty() && node_to_recv.empty()) { + return SUCCESS; + } + + ret = ReuseEvent(true, name_to_node_map, node_to_send); + if (ret != SUCCESS) { + GELOGE(ret, "[Reuse][Event]Phase:Send, graph:%s.", whole_graph_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to reuse event, phase:Send, graph:%s.", whole_graph_->GetName().c_str()); + return ret; + } + + ret = ReuseEvent(false, name_to_node_map, node_to_recv); + if (ret != SUCCESS) { + GELOGE(ret, "[Reuse][Event]Phase:Recv, graph:%s.", whole_graph_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to reuse event, phase:Recv, graph:%s.", whole_graph_->GetName().c_str()); + return ret; + } + + Status status = RefreshContinuousEvents(); + if (status != SUCCESS) { + GELOGE(status, "[Refresh][ContinuousEvents]Graph:%s.", whole_graph_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to refresh continuous events, graph:%s.", whole_graph_->GetName().c_str()); + return status; + } + GELOGI("[Refresh][Events]RefreshEventsWithReuse successfully, event num:%u.", event_num_); + return SUCCESS; +} + // Refresh events to continuous events Status StreamAllocator::RefreshContinuousEvents() { // Establish a mapping relationship from old to new event id @@ -1168,8 +1362,10 @@ Status StreamAllocator::RefreshContinuousEvents() { uint32_t new_event_id = 0; for (const auto &one_pair : node_to_send_events_) { for (const auto &event_id : one_pair.second) { - old_to_new_events[event_id] = new_event_id; - new_event_id++; + if (old_to_new_events.find(event_id) == old_to_new_events.end()) { + old_to_new_events[event_id] = new_event_id; + new_event_id++; + } } } @@ -1208,6 +1404,7 @@ Status StreamAllocator::RefreshContinuousEvents() { // Insert the real send/recv node in the graph Status StreamAllocator::InsertSyncEventNodes() { + unordered_map sync_event_name; for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { // Add the node corresponding to the recv event vector recv_event_id_list; @@ -1217,6 +1414,13 @@ Status StreamAllocator::InsertSyncEventNodes() { GE_CHECK_NOTNULL(node->GetOutControlAnchor()); for (auto &event_id : recv_event_id_list) { string recv_node_name = whole_graph_->GetName() + "_Recv_" + to_string(event_id); + auto iter = sync_event_name.find(recv_node_name); + if (iter == sync_event_name.end()) { + sync_event_name[recv_node_name] = 1; + } else { + recv_node_name = recv_node_name + "_Reuse_" + to_string(iter->second); + ++(iter->second); + } OpDescPtr op_desc_ptr = MakeShared(recv_node_name, RECV); GE_CHECK_NOTNULL(op_desc_ptr); @@ -1251,6 +1455,13 @@ Status StreamAllocator::InsertSyncEventNodes() { for (auto &event_id : send_event_id_list) { string send_node_name = whole_graph_->GetName() + "_Send_" + to_string(event_id); + auto iter = sync_event_name.find(send_node_name); + if (iter == sync_event_name.end()) { + sync_event_name[send_node_name] = 1; + } else { + send_node_name = send_node_name + "_Reuse_" + to_string(iter->second); + ++(iter->second); + } OpDescPtr op_desc_ptr = MakeShared(send_node_name, SEND); GE_CHECK_NOTNULL(op_desc_ptr); @@ -1300,12 +1511,16 @@ void StreamAllocator::DumpEvents() { GELOGD("After RefreshRealStream: stream %ld.", stream_id); for (const auto &node : one_pair.second) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + continue; + } string send_event_str; for (const auto &send_event_id : node_to_send_events_[node]) { send_event_str += " " + to_string(send_event_id); } if (!send_event_str.empty()) { - GELOGI("node: %s, send events: %s", node->GetName().c_str(), send_event_str.c_str()); + GELOGI("node: %s, id: %ld, stream id :%ld, send events: %s.", node->GetName().c_str(), + node->GetOpDesc()->GetId(), node->GetOpDesc()->GetStreamId(), send_event_str.c_str()); } string recv_event_str; @@ -1313,7 +1528,8 @@ void StreamAllocator::DumpEvents() { recv_event_str += " " + to_string(recv_event_id); } if (!recv_event_str.empty()) { - GELOGI("node: %s, recv events: %s", node->GetName().c_str(), recv_event_str.c_str()); + GELOGI("node: %s, id: %ld, stream id :%ld, recv events: %s.", node->GetName().c_str(), + node->GetOpDesc()->GetId(), node->GetOpDesc()->GetStreamId(), recv_event_str.c_str()); } } } diff --git a/ge/graph/build/stream_allocator.h b/ge/graph/build/stream_allocator.h index dd82700d..44dcd673 100644 --- a/ge/graph/build/stream_allocator.h +++ b/ge/graph/build/stream_allocator.h @@ -71,6 +71,10 @@ class StreamAllocator { Status SetActiveStreamsForLoop(); Status CheckStreamActived() const; + Status ReuseEvent(bool send_to, + const std::unordered_map &name_to_node_map, + const std::unordered_map>> &node_to_event_id); + Status RefreshEventsWithReuse(); Status RefreshContinuousEvents(); Status InsertSyncEventNodes(); diff --git a/ge/graph/common/omg_util.cc b/ge/graph/common/omg_util.cc index b0d64a41..272707a5 100644 --- a/ge/graph/common/omg_util.cc +++ b/ge/graph/common/omg_util.cc @@ -21,6 +21,8 @@ #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "common/math/math_util.h" namespace ge { /// @@ -204,4 +206,42 @@ Status SetNextIteration(const ge::NodePtr &node, const std::string &next) { return SUCCESS; } + +/// +/// @brief Align the memory +/// @param [in/out] memory size +/// @param [in] alinment +/// @return void +/// +void AlignMemSize(int64_t &mem_size, int64_t align_size) { + if (mem_size <= 0) { + return; + } + mem_size = (mem_size + align_size - 1) / align_size * align_size; +} + +/// +/// @brief Get memory size from tensor desc +/// @param [in] node +/// @param [out] memory size +/// @return Status +/// +Status GetMemorySize(const NodePtr &node, int64_t &output_size) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + auto output_op_desc = node->GetOpDesc()->GetOutputDescPtr(kBufferPoolNodeOutIndex); + GE_CHECK_NOTNULL(output_op_desc); + int64_t size = 0; + auto ret = ge::TensorUtils::GetSize(*output_op_desc, size); + if (ret != ge::GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Get][Size]Node:%s.", node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to get output size, node:%s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + FMK_INT64_ADDCHECK(size, kBufferPoolMemAlignSize); + AlignMemSize(size, kBufferPoolMemAlignSize); + // The HCOM operator requires an additional 512 bytes before and after + FMK_INT64_ADDCHECK(size, (kBufferPoolMemAlignSize + kBufferPoolMemAlignSize)); + output_size = kBufferPoolMemAlignSize + size + kBufferPoolMemAlignSize; + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/common/omg_util.h b/ge/graph/common/omg_util.h index 1f93c92b..561a12e0 100644 --- a/ge/graph/common/omg_util.h +++ b/ge/graph/common/omg_util.h @@ -27,6 +27,11 @@ #include "graph/node.h" namespace ge { +namespace { +const int64_t kBufferPoolMemAlignSize = 512; +const uint32_t kBufferPoolNodeOutIndex = 0; +const uint32_t kEventReuseThreshold = 65500; +} // namespace /// /// @brief get the Original Type of FrameworkOp /// @param [in] node @@ -96,6 +101,22 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node); /// @return Status /// Status SetNextIteration(const ge::NodePtr &node, const std::string &next); + +/// +/// @brief Align the memory +/// @param [in/out] memory size +/// @param [in] alinment +/// @return void +/// +void AlignMemSize(int64_t &mem_size, int64_t align_size); + +/// +/// @brief Get memory size from tensor desc +/// @param [in] node +/// @param [out] memory size +/// @return Status +/// +Status GetMemorySize(const NodePtr &node, int64_t &output_size); } // namespace ge #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 6b347a9d..c29ca475 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -60,6 +60,7 @@ #include "securec.h" #include "graph/common/local_context.h" #include "common/formats/utils/formats_trans_utils.h" +#include "graph/common/omg_util.h" // create std::thread, catch exceptions using try/catch #define CREATE_STD_THREAD(thread_id, func, args) \ @@ -664,9 +665,12 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size GELOGI("Logical stream index:%u, stream:%p, rtstream: %d.", i, stream, rt_stream_id); } - for (uint32_t i = 0; i < EventNum(); i++) { - rtEvent_t rt_event; - GE_CHK_RT_RET(rtEventCreate(&rt_event)); + uint32_t event_num = EventNum(); + uint32_t create_flag = static_cast((event_num > kEventReuseThreshold) ? RT_EVENT_WITH_FLAG : + RT_EVENT_DEFAULT); + for (uint32_t i = 0; i < event_num; ++i) { + rtEvent_t rt_event = nullptr; + GE_CHK_RT_RET(rtEventCreateWithFlag(&rt_event, create_flag)); event_list_.push_back(rt_event); } diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index e122e28f..9ef04131 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -95,6 +95,7 @@ #include "graph/passes/memcpy_addr_async_pass.h" #include "graph/passes/hccl_continuous_memcpy_pass.h" #include "graph/passes/parallel_group_pass.h" +#include "graph/passes/buffer_pool_memory_pass.h" #include "graph/build/label_allocator.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" @@ -2528,6 +2529,12 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed."); GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); + // Process offset and dependency for buffer pool memory assigner. + GE_TIMESTAMP_START(BufferPoolMemoryPass); + BufferPoolMemoryPass buffer_pool_mem_pass; + GE_CHK_STATUS_RET(buffer_pool_mem_pass.Run(compute_graph), "Failed to process for buffer pool allocator."); + GE_TIMESTAMP_END(BufferPoolMemoryPass, "BufferPoolMemoryPass::Run."); + // Handle parallel group . GE_TIMESTAMP_START(ParallelGroup); ParallelGroupPass parallel_group_pass; diff --git a/ge/graph/passes/buffer_pool_memory_pass.cc b/ge/graph/passes/buffer_pool_memory_pass.cc new file mode 100644 index 00000000..8a64da59 --- /dev/null +++ b/ge/graph/passes/buffer_pool_memory_pass.cc @@ -0,0 +1,574 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/buffer_pool_memory_pass.h" + +#include +#include +#include "graph/common/omg_util.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "common/math/math_util.h" + +namespace ge { +namespace { +const size_t kBufferPoolNodeInSize = 1; +const size_t kBufferPoolNodeOutSize = 1; +} // namespace + +Status BufferPoolMemoryPass::Run(ComputeGraphPtr graph) { + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "[Check][Graph]Graph is nullptr"); + REPORT_INNER_ERROR("E19999", "Input graph is nullptr"); + return PARAM_INVALID; + } + // The cache prefetching scheme is developed for very large models, which gets the weight data in advance + // and allocates it to a special memory pool. When the large model is dynamic shape, it need to go through + // the executor flow and is not allocated memory statically. This is another development point, so we will + // skip the dynamic shape model processing here. + if (graph->GetParentGraph() != nullptr || graph->GetGraphUnknownFlag()) { + return SUCCESS; + } + if (!IsBufferPoolMemEnable(graph)) { + GELOGD("[Check][Enable]Buffer pool memory is not enable, graph:%s.", graph->GetName().c_str()); + return SUCCESS; + } + Status ret = graph->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str()); + return ret; + } + + ret = CopyOutForMultiUsedOutput(graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Copy][Output]Graph:%s.", graph->GetName().c_str()); + return FAILED; + } + + ret = GetBufferPoolAndPeerCalcNodes(graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Get][BufferPoolNode]Graph:%s.", graph->GetName().c_str()); + return FAILED; + } + if (calc_nodes_.empty()) { + GELOGE(FAILED, "[Check][BufferPoolNode]Graph:%s.", graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "All Buffer pool nodes are isolated nodes in graph:%s.", graph->GetName().c_str()); + return FAILED; + } + ret = AllocateAllBufferPoolSpace(); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Alloc][BufferPoolMem]Graph:%s.", graph->GetName().c_str()); + return FAILED; + } + + ret = SetResultOfMemoryAndEvent(); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Set][Result]Graph:%s.", graph->GetName().c_str()); + return FAILED; + } + ret = graph->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str()); + return ret; + } + return SUCCESS; +} + +void BufferPoolMemoryPass::ClearQueue(std::queue> &q) { + while (!q.empty()) { + q.pop(); + } +} + +Status BufferPoolMemoryPass::IsBufferPoolMemEnable(const ComputeGraphPtr &graph) { + for (NodePtr &node : graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + if (op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) { + return true; + } + } + return false; +} + +Status BufferPoolMemoryPass::CheckBufferPoolSize(int64_t total_size, int64_t pool_id, int64_t buffer_pool_size, + std::unordered_map &calc_total_size) { + auto iter = calc_total_size.find(pool_id); + if (iter == calc_total_size.end()) { + calc_total_size[pool_id] = total_size; + } else { + FMK_INT64_ADDCHECK(calc_total_size[pool_id], total_size); + calc_total_size[pool_id] += total_size; + } + if (calc_total_size[pool_id] > buffer_pool_size) { + GELOGE(INTERNAL_ERROR, "[Check][Size]The memory required at the same is greater than buffer pool size, " + "pool id:%ld, pool size:%ld, required size:%ld.", pool_id, buffer_pool_size, calc_total_size[pool_id]); + REPORT_INNER_ERROR("E19999", "The memory required at the same is greater than buffer pool size, pool id:%ld," + " pool size:%ld, required size:%ld.", pool_id, buffer_pool_size, calc_total_size[pool_id]); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status BufferPoolMemoryPass::TryToFixNodeOrder(NodePtr &pre_node, NodePtr &curr_node, bool ¬_change) { + auto pre_node_graph = pre_node->GetOwnerComputeGraph(); + auto curr_node_graph = curr_node->GetOwnerComputeGraph(); + std::string pre_node_stream_label; + (void) AttrUtils::GetStr(pre_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, pre_node_stream_label); + std::string curr_node_stream_label; + (void) AttrUtils::GetStr(curr_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, curr_node_stream_label); + not_change = true; + if ((pre_node_graph == curr_node_graph) && (pre_node_stream_label == pre_node_stream_label)) { + // Same subgraph, including simultaneously in the root graph. + auto ret = ge::GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), curr_node->GetInControlAnchor()); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Add][Edge]Src:%s, dst:%s.", pre_node->GetName().c_str(), curr_node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to add ctrl edge from %s to %s.", + pre_node->GetName().c_str(), curr_node->GetName().c_str()); + return INTERNAL_ERROR; + } + not_change = false; + } else if (pre_node_graph->GetParentGraph() == curr_node_graph->GetParentGraph() && + pre_node_graph->GetParentNode() != nullptr && curr_node_graph->GetParentNode() != nullptr) { + // Two nodes are located on different child graphs of different parent nodes. + auto pre_node_parent_op_desc = pre_node_graph->GetParentNode()->GetOpDesc(); + auto curr_node_parent_op_desc = curr_node_graph->GetParentNode()->GetOpDesc(); + GE_CHECK_NOTNULL(pre_node_parent_op_desc); + GE_CHECK_NOTNULL(curr_node_parent_op_desc); + // The parent node dependency is correct to ensure that the child node dependency, + // there is no need to add control edges. + if (pre_node_parent_op_desc->GetId() > curr_node_parent_op_desc->GetId()) { + GELOGE(INTERNAL_ERROR, "[Check][Dependency]Invalid dependency, pre node:%s, curr node:%s.", + pre_node->GetName().c_str(), curr_node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Invalid dependency, pre node:%s, curr node:%s.", + pre_node->GetName().c_str(), curr_node->GetName().c_str()); + return INTERNAL_ERROR; + } + GELOGI("[Check][Dependency]The two nodes are located in sub graphs of different parent nodes and meet the " + "dependency relationship. pre:%s, curr:%s.", pre_node->GetName().c_str(), curr_node->GetName().c_str()); + } else { + GELOGE(INTERNAL_ERROR, "[Check][Dependency]Invalid dependency, pre node:%s, curr node:%s.", + pre_node->GetName().c_str(), curr_node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Invalid dependency, pre node:%s, curr node:%s.", + pre_node->GetName().c_str(), curr_node->GetName().c_str()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status BufferPoolMemoryPass::InsertMemCpyNodeAfter(ComputeGraphPtr &graph, NodePtr &node) { + auto out_anchor = node->GetOutDataAnchor(kBufferPoolNodeOutIndex); + OpDescBuilder op_desc_builder(node->GetName() + "_memcpy_async", MEMCPYASYNC); + auto mem_copy_op = op_desc_builder.AddInput("x", node->GetOpDesc()->GetOutputDesc(kBufferPoolNodeOutIndex)) + .AddOutput("y", node->GetOpDesc()->GetOutputDesc(kBufferPoolNodeOutIndex)) + .Build(); + std::string batch_label; + bool get_attr = AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, batch_label); + if (get_attr && !batch_label.empty()) { + (void) AttrUtils::SetStr(mem_copy_op, ATTR_NAME_STREAM_LABEL, batch_label); + } + auto peer_in_anchors = out_anchor->GetPeerInDataAnchors(); + std::vector in_anchors(peer_in_anchors.begin(), peer_in_anchors.end()); + if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(mem_copy_op)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "[Insert][Node] Node:%s.", node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to insert mem copy node after %s.", node->GetName().c_str()); + return FAILED; + } + return SUCCESS; +} + +Status BufferPoolMemoryPass::CopyOutForMultiUsedOutput(ComputeGraphPtr &graph) { + bool changed = false; + for (NodePtr &node : graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + bool use_buffer_pool = op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE); + if (use_buffer_pool) { + if ((node->GetInDataNodes().size() == kBufferPoolNodeInSize) && + (node->GetOutDataNodes().size() == kBufferPoolNodeOutSize)) { + continue; + } else if ((node->GetAllInDataAnchors().size() == kBufferPoolNodeInSize) && + (node->GetAllOutDataAnchors().size() == kBufferPoolNodeOutSize)) { + // A prefetching output is used in multiple places. Copy one so that the prefetching node remains + // single input and single output. + if (InsertMemCpyNodeAfter(graph, node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Insert][MemCpy]Node:%s.", node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to insert mem copy node after %s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + changed = true; + GELOGI("[Insert][Node]Insert mem copy node after %s.", node->GetName().c_str()); + } else { + GELOGE(PARAM_INVALID, "[Check][InputOutput]Only support single input and single output, " + "node:%s.", node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Only support single input and single output, node:%s.", node->GetName().c_str()); + return PARAM_INVALID; + } + } + } + if (changed) { + Status ret = graph->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str()); + return ret; + } + } + return SUCCESS; +} + +Status BufferPoolMemoryPass::GetBufferPoolAndPeerCalcNodes(const ComputeGraphPtr &graph) { + std::unordered_map>> unique_calc_nodes; + for (const NodePtr &node : graph->GetAllNodes()) { + auto in_data_nodes = node->GetInAllNodes(); + for (NodePtr &in_node : in_data_nodes) { + int64_t buffer_pool_id = 0; + int64_t buffer_pool_size = 0; + bool get_attr = AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, buffer_pool_id); + get_attr = get_attr && (AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, buffer_pool_size)); + if (get_attr) { + std::string batch_label; + (void) AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); + peer_buffer_node_item_[batch_label][node].emplace_back(BufferPoolNodeItem(in_node, 0, 0)); + buffer_node_to_calc_[batch_label][in_node] = node; + if (unique_calc_nodes[batch_label][buffer_pool_id].count(node) == 0) { + calc_nodes_[batch_label][buffer_pool_id].emplace_back(node); + unique_calc_nodes[batch_label][buffer_pool_id].insert(node); + } + GELOGI("[Get][BufferNode]Calc node:%s, pool node:%s.", node->GetName().c_str(), in_node->GetName().c_str()); + Status ret = SetBufferPoolSize(batch_label, buffer_pool_id, buffer_pool_size); + if (ret != SUCCESS) { + GELOGE(ret, "[Set][BufferPoolSize]Node:%s", in_node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to set buffer pool size, something wrong with the info of node:%s", + in_node->GetName().c_str()); + return ret; + } + } + } + } + return SUCCESS; +} + +Status BufferPoolMemoryPass::SetBufferPoolSize(const std::string &batch_label, int64_t id, int64_t size) { + auto iter = buffer_pool_size_[batch_label].find(id); + if (iter != buffer_pool_size_[batch_label].end() && iter->second != size) { + GELOGE(PARAM_INVALID, "[Check][BufferPoolSize]Get different size with the same id, " + "id:%ld, original size:%ld, this size:%ld.", id, iter->second, size); + REPORT_INNER_ERROR("E19999", "Get different size with the same id, " + "id:%ld, original size:%ld, this size:%ld.", id, iter->second, size); + return PARAM_INVALID; + } + buffer_pool_size_[batch_label][id] = size; + return SUCCESS; +} + +Status BufferPoolMemoryPass::AllocateAllBufferPoolSpace() { + for (const auto &iter : calc_nodes_) { + std::string batch_label = iter.first; + Status ret = AllocateSpaceInBatch(calc_nodes_[batch_label], + buffer_pool_size_[batch_label], + buffer_node_to_calc_[batch_label], + peer_buffer_node_item_[batch_label]); + if (ret != SUCCESS) { + GELOGE(ret, "[Alloc][InBatch]Batch_label:%s.", batch_label.c_str()); + REPORT_INNER_ERROR("E19999", "Failed to allocate space in batch, batch_label:%s.", batch_label.c_str()); + return ret; + } + GELOGI("[Alloc][InBatch]Alloc space in batch successfully, batch label:%s.", batch_label.c_str()); + } + return SUCCESS; +} + +Status BufferPoolMemoryPass::AllocateSpaceInBatch( + const std::map> &calc_nodes, + const std::unordered_map &buffer_pool_size_map, + const std::unordered_map &buffer_node_to_calc, + std::unordered_map> &buffer_pool_nodes_item) { + for (const auto &calc_node_in_pool : calc_nodes) { + int64_t pool_id = calc_node_in_pool.first; + int64_t buffer_pool_size = buffer_pool_size_map.at(pool_id); + ClearQueue(mem_ctrl_event_); + ClearQueue(stream_ctrl_event_); + BufferPool buffer_pool(pool_id, buffer_pool_size, buffer_node_to_calc); + Status ret = AllocateSpaceInBufferPool(buffer_pool, + calc_node_in_pool.second, + buffer_pool_nodes_item); + if (ret != SUCCESS) { + GELOGE(ret, "[Alloc][InBufferPool]Pool id:%ld, pool size:%ld.", pool_id, buffer_pool_size); + REPORT_INNER_ERROR("E19999", "Failed to allocate space in buffer pool, id:%ld, pool size:%ld.", + pool_id, buffer_pool_size); + return ret; + } + GELOGI("[Alloc][InBufferPool]Alloc space in buffer pool successfully, pool id:%ld.", pool_id); + } + return SUCCESS; +} + +Status BufferPoolMemoryPass::AllocateSpaceInBufferPool( + const BufferPool &buffer_pool, + const std::vector &calc_nodes_in_pool, + std::unordered_map> &buffer_pool_nodes_item) { + int64_t pool_id = buffer_pool.pool_id; + int64_t buffer_pool_size = buffer_pool.pool_size; + int64_t next_start = 0; + NodePtr pre_buffer_pool_node = nullptr; + std::queue node_mem_range_in_pool; + node_mem_range_in_pool.push(BufferPoolMemoryPass::BufferPoolNodeItem(nullptr, 0, buffer_pool_size)); + for (auto &calc_node : calc_nodes_in_pool) { + auto &peer_buffer_node_item = buffer_pool_nodes_item[calc_node]; + std::unordered_map calc_total_size; + size_t input_buffer_node_num = 0; + for (auto &node_item : peer_buffer_node_item) { + auto peer_buffer_node = node_item.node; + GE_CHECK_NOTNULL(peer_buffer_node); + int64_t total_size = 0; + ++input_buffer_node_num; + Status ret = GetMemorySize(peer_buffer_node, total_size); + if (ret != SUCCESS) { + GELOGE(ret, "[Get][MemSize]Node:%s, calc_node:%s.", + peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to get memory size, node:%s, calc_node:%s.", + peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str()); + return ret; + } + ret = CheckBufferPoolSize(total_size, pool_id, buffer_pool_size, calc_total_size); + if (ret != SUCCESS) { + GELOGE(ret, "[Check][BufferPoolSize]Capacity is not enough for all data, calc_node:%s.", + calc_node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Capacity is not enough for all data, calc_node:%s.", + calc_node->GetName().c_str()); + return ret; + } + BufferPoolNodeItem buffer_pool_node_item(peer_buffer_node, calc_node, pre_buffer_pool_node, total_size, + 0, 0, (input_buffer_node_num == peer_buffer_node_item.size())); + ret = AllocateSpaceForBufferPoolNode(next_start, buffer_pool, buffer_pool_node_item, node_mem_range_in_pool); + if (ret != SUCCESS) { + GELOGE(ret, "[Alloc][ForNode]Pool node:%s, calc_node:%s.", + peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to allocate space for buffer pool node:%s, calc_node:%s.", + peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str()); + return ret; + } + pre_buffer_pool_node = peer_buffer_node; + } + } + return SUCCESS; +} + +Status BufferPoolMemoryPass::AllocateSpaceForBufferPoolNode(int64_t &next_start, + const BufferPool buffer_pool, + BufferPoolNodeItem &buffer_pool_node_item, + std::queue &node_mem_range_in_pool) { + // Get event id must be before FixTheTimingOfDependentNodes + uint32_t logic_event = logic_event_num_; + NodePtr buffer_node = buffer_pool_node_item.node; + NodePtr calc_node = buffer_pool_node_item.out_calc_node; + /// In the scenario where there are multiple PREFETCH operators in the inputs of the calculation operator, + /// the addition of events is optimized to only add events after the last PREFETCH operator. + /// w1 w2 w3 w4 w5 + /// | | | | | + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 xxx + /// \ / \ / \ / + /// \ / \ / \ / + /// \ / \ / \ / + /// node1 node2 node3 + /// | | | + /// | | | + /// --------------- other nodes ------------ + /// + /// The event id of the PREFETCH operator to the calculation operator needs to be generated before + /// FixTheTimingOfDependentNodes, because FixTheTimingOfDependentNodes may add a new id to stream_ctrl_event_, + /// and this id cannot be reused until the next PREFETCH operator in the sequence. + if (buffer_pool_node_item.is_last_input) { + logic_event = GenerateEventId(buffer_node->GetName(), stream_ctrl_event_); + node_event_multiplexing_[buffer_node].push_back(string("SendTo;" + calc_node->GetName() + + ";" + std::to_string(logic_event))); + mem_ctrl_event_.push(std::make_pair(calc_node->GetName(), logic_event)); + } + NodePtr dependent_calc_node = GetOffsetAndDependency(next_start, buffer_pool_node_item.total_size, + buffer_pool.pool_size, + buffer_pool.buffer_node_to_calc, + node_mem_range_in_pool); + if (dependent_calc_node != nullptr) { + Status ret = FixTheTimingOfDependentNodes(dependent_calc_node, buffer_node); + if (ret != SUCCESS) { + GELOGE(ret, "[Fix][Timing]Pool_id:%ld, pool node:%s, dependent node:%s.", + buffer_pool.pool_id, buffer_node->GetName().c_str(), dependent_calc_node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to fix timing, pool_id:%ld, pool node:%s, dependent node:%s.", + buffer_pool.pool_id, buffer_node->GetName().c_str(), + dependent_calc_node->GetName().c_str()); + return ret; + } + } + + buffer_pool_node_item.offset_start = next_start; + buffer_node_logical_offset_[buffer_node].push_back(buffer_pool_node_item.total_size); + buffer_node_logical_offset_[buffer_node].push_back(next_start); + FMK_INT64_ADDCHECK(next_start, buffer_pool_node_item.total_size); + next_start += buffer_pool_node_item.total_size; + buffer_pool_node_item.offset_end = next_start; + node_mem_range_in_pool.push(buffer_pool_node_item); + if (buffer_pool_node_item.pre_buffer_pool_node != nullptr) { + bool not_change = true; + auto ret = TryToFixNodeOrder(buffer_pool_node_item.pre_buffer_pool_node, buffer_node, not_change); + if (ret != SUCCESS) { + GELOGE(ret, "[Fix][BufferPoolNodeOrder]Pre node:%s, curr node:%s.", + buffer_pool_node_item.pre_buffer_pool_node->GetName().c_str(), buffer_node->GetName().c_str()); + return ret; + } + } + GELOGI("[Alloc][ForNode]Buffer pool node %s send to %s, offset start:%ld, send event id:%u.", + buffer_node->GetName().c_str(), calc_node->GetName().c_str(), + buffer_pool_node_item.offset_start, logic_event); + return SUCCESS; +} + +/// When generating the event ID, determine whether the name of the queue head node is the same as the name of +/// the operator, in order to handle such scenarios: +/// w1 w2 w3 w4 w5 +/// | | | | | +/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 +/// | | | | | +/// node1 node2 node3 node4 node5 +/// +/// Memory distribution: +/// +/// |____w1_____|__| +/// +/// |____w2_____|__| +/// +/// |____w3_____|__| +/// +/// |______w4______| +/// +/// |______w5______| +/// +/// In this scenario, prefetch2 depends on node1. If the dependency is handled by adding an event of node1 to prefetch2, +/// the id sent by prefetch2 will be the same as the id it receives.Although Runtime supports this through WaitReset, +/// we consider this a dangerous operation and avoid it. +uint32_t BufferPoolMemoryPass::GenerateEventId(const std::string &node_name, + std::queue> &event_queue) { + uint32_t logic_event = logic_event_num_; + if (!event_queue.empty()) { + auto item = event_queue.front(); + if (item.first != node_name) { + logic_event = item.second; + event_queue.pop(); + return logic_event; + } + } + ++logic_event_num_; + return logic_event; +} + +NodePtr BufferPoolMemoryPass::GetOffsetAndDependency(int64_t &next_start, + int64_t total_mem_size, + int64_t buffer_pool_size, + const std::unordered_map &buffer_node_to_calc, + std::queue &nodes_in_buffer) { + // The buffer pool can no longer fit this Tensor and needs to turn back. + if (next_start + total_mem_size > buffer_pool_size) { + next_start = 0; + if (!nodes_in_buffer.empty()) { + // Take up the rest of the space at the end, + nodes_in_buffer.back().offset_end = buffer_pool_size; + // Pop the first tensor memory in the previous round of the previous round. + nodes_in_buffer.pop(); + } + while (!nodes_in_buffer.empty()) { + auto node_item = nodes_in_buffer.front(); + // Go to the begin of previous round. + if (node_item.offset_start == 0) { + break; + } + nodes_in_buffer.pop(); + } + } + + while (!nodes_in_buffer.empty()) { + auto node_item = nodes_in_buffer.front(); + if (next_start + total_mem_size <= node_item.offset_end) { + auto pool_node = node_item.node; + if (pool_node == nullptr) { + return nullptr; + } + auto output_calc = buffer_node_to_calc.find(pool_node); + if (output_calc != buffer_node_to_calc.end()) { + return output_calc->second; + } + return nullptr; + } + nodes_in_buffer.pop(); + } + return nullptr; +} + +Status BufferPoolMemoryPass::FixTheTimingOfDependentNodes(NodePtr &dependent_calc_node, NodePtr &curr_pool_node) { + // The previous process ensures that all pointers are not null. + bool not_change = false; + Status ret = TryToFixNodeOrder(dependent_calc_node, curr_pool_node, not_change); + if (ret != SUCCESS) { + GELOGE(ret, "[Fix][NodeOrder]Src:%s, dst:%s.", + dependent_calc_node->GetName().c_str(), curr_pool_node->GetName().c_str()); + return ret; + } + if (not_change) { + return SUCCESS; + } + uint32_t logic_event = GenerateEventId(dependent_calc_node->GetName(), mem_ctrl_event_); + node_event_multiplexing_[curr_pool_node].push_back(string("RecvFrom;" + dependent_calc_node->GetName() + + ";" + std::to_string(logic_event))); + stream_ctrl_event_.push(std::make_pair(curr_pool_node->GetName(), logic_event)); + GELOGI("[Fix][Timing]Add ctrl edge for buffer pool memory from %s to %s, buffer pool node recv event:%u.", + dependent_calc_node->GetName().c_str(), curr_pool_node->GetName().c_str(), logic_event); + return SUCCESS; +} + +Status BufferPoolMemoryPass::SetResultOfMemoryAndEvent() { + for (auto &iter : node_event_multiplexing_) { + auto node = iter.first; + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + bool ret = AttrUtils::SetListStr(op_desc, ATTR_NAME_EVENT_MULTIPLEXING, iter.second); + if (!ret) { + GELOGE(INTERNAL_ERROR, "[Set][Attr]Node:%s.", node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to set event reuse info, node:%s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + auto offset_iter = buffer_node_logical_offset_.find(node); + if (offset_iter == buffer_node_logical_offset_.end()) { + GELOGE(INTERNAL_ERROR, "[Get][LogicalOffset]Node:%s.", node->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Failed to get logical offset and size, node:%s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + ret = AttrUtils::SetListInt(op_desc, ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET, offset_iter->second); + if (!ret) { + GELOGE(INTERNAL_ERROR, "[Set][Attr]Node:%s.", node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to set node memory offset and size, node:%s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/buffer_pool_memory_pass.h b/ge/graph/passes/buffer_pool_memory_pass.h new file mode 100644 index 00000000..e3d1c159 --- /dev/null +++ b/ge/graph/passes/buffer_pool_memory_pass.h @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_ +#define GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_ + +#include +#include "graph/graph.h" +#include "inc/graph_pass.h" + +namespace ge { +class BufferPoolMemoryPass : public GraphPass { + public: + explicit BufferPoolMemoryPass() : logic_event_num_(0) {} + + ~BufferPoolMemoryPass() override = default; + + struct BufferPool { + int64_t pool_id = 0; + int64_t pool_size = 0; + std::unordered_map buffer_node_to_calc; + BufferPool(int64_t id, int64_t size, const std::unordered_map &node_map) + : pool_id(id), pool_size(size), buffer_node_to_calc(node_map) {} + }; + + struct BufferPoolNodeItem { + NodePtr node = nullptr; + NodePtr out_calc_node = nullptr; + NodePtr pre_buffer_pool_node = nullptr; + int64_t total_size = 0; + int64_t offset_start = 0; + int64_t offset_end = 0; + bool is_last_input = true; + BufferPoolNodeItem(const NodePtr &buffer_n, const NodePtr &calc_n, const NodePtr &pre_buffer_n, + int64_t size, int64_t start, int64_t end, bool last) + : node(std::move(buffer_n)), + out_calc_node(std::move(calc_n)), + pre_buffer_pool_node(std::move(pre_buffer_n)), + total_size(size), + offset_start(start), + offset_end(end), + is_last_input(last) {} + + BufferPoolNodeItem(const NodePtr &buffer_n, int64_t start, int64_t end) + : node(std::move(buffer_n)), + out_calc_node(nullptr), + pre_buffer_pool_node(nullptr), + total_size(0), + offset_start(start), + offset_end(end), + is_last_input(true) {} + }; + + Status Run(ComputeGraphPtr graph) override; + + private: + static void ClearQueue(std::queue> &q); + + static Status IsBufferPoolMemEnable(const ComputeGraphPtr &graph); + + static Status CheckBufferPoolSize(int64_t total_size, int64_t pool_id, int64_t buffer_pool_size, + std::unordered_map &calc_total_size); + + static Status TryToFixNodeOrder(NodePtr &pre_node, NodePtr &curr_node, bool ¬_change); + + Status InsertMemCpyNodeAfter(ComputeGraphPtr &graph, NodePtr &node); + + Status CopyOutForMultiUsedOutput(ComputeGraphPtr &graph); + + Status GetBufferPoolAndPeerCalcNodes(const ComputeGraphPtr &graph); + + Status SetBufferPoolSize(const std::string &batch_label, int64_t id, int64_t size); + + Status AllocateAllBufferPoolSpace(); + + Status AllocateSpaceInBatch(const std::map> &calc_nodes, + const std::unordered_map &buffer_pool_size_map, + const std::unordered_map &buffer_node_to_calc, + std::unordered_map> &buffer_pool_nodes_item); + + Status AllocateSpaceInBufferPool(const BufferPool &buffer_pool, + const std::vector &calc_nodes_in_pool, + std::unordered_map> &buffer_pool_nodes_item); + + Status AllocateSpaceForBufferPoolNode(int64_t &next_start, + const BufferPool buffer_pool, + BufferPoolNodeItem &buffer_pool_node_item, + std::queue &node_mem_range_in_pool); + + NodePtr GetOffsetAndDependency(int64_t &next_start, + int64_t total_mem_size, + int64_t buffer_pool_size, + const std::unordered_map &buffer_node_to_calc, + std::queue &nodes_in_buffer); + + Status FixTheTimingOfDependentNodes(NodePtr &dependent_calc_node, NodePtr &curr_pool_node); + + uint32_t GenerateEventId(const std::string &node_name, std::queue> &event_queue); + + Status SetResultOfMemoryAndEvent(); + + // Use map to ensure that each visit is in the order of batch label and pool id + std::map>> calc_nodes_; + + std::unordered_map> buffer_node_to_calc_; + + std::unordered_map>> peer_buffer_node_item_; + + std::unordered_map> buffer_pool_size_; + + uint32_t logic_event_num_; + + std::queue> mem_ctrl_event_; + + std::queue> stream_ctrl_event_; + + std::unordered_map> node_event_multiplexing_; + + std::unordered_map> buffer_node_logical_offset_; +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_ diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index b062ec80..00873b8f 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -43,6 +43,11 @@ rtError_t rtEventCreate(rtEvent_t *event) { *event = new int[EVENT_LENTH]; return RT_ERROR_NONE; } + +rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag) { + return rtEventCreate(event); +} + rtError_t rtEventRecord(rtEvent_t event, rtStream_t stream) { return RT_ERROR_NONE; } rtError_t rtEventSynchronize(rtEvent_t event) { return RT_ERROR_NONE; } diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index fcb1e6aa..dbfc93a1 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -276,6 +276,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc" "${GE_CODE_DIR}/ge/model/ge_model.cc" "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" @@ -323,6 +324,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/build/memory/block_mem_assigner.cc" "${GE_CODE_DIR}/ge/graph/build/memory/binary_block_mem_assigner.cc" "${GE_CODE_DIR}/ge/graph/build/memory/max_block_mem_assigner.cc" + "${GE_CODE_DIR}/ge/graph/build/memory/buffer_pool_mem_assigner.cc" "${GE_CODE_DIR}/ge/graph/manager/graph_mem_allocator.cc" "${GE_CODE_DIR}/ge/graph/manager/graph_var_manager.cc" "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" @@ -627,6 +629,7 @@ set(SINGLE_OP_SRC_FILES # test files set(COMMON_TEST_FILES "graph/passes/graph_builder_utils.cc" + "graph/utils/buffer_pool_graph_builder.cc" "test.cc" ) @@ -703,6 +706,7 @@ set(PASS_TEST_FILES "graph/passes/link_gen_mask_nodes_pass_unittest.cc" "graph/passes/transpose_transdata_pass_unittest.cc" "graph/passes/parallel_group_pass_unittest.cc" + "graph/passes/buffer_pool_memory_pass_unittest.cc" ) set(KERNEL_TEST_FILES @@ -771,6 +775,7 @@ set(MULTI_PARTS_TEST_FILES "graph/build/model_builder_unittest.cc" "graph/build/mem_assigner_unittest.cc" "graph/build/task_generator_unittest.cc" + "graph/build/buffer_pool_mem_assigner_unittest.cc" "graph/preprocess/graph_preprocess_unittest.cc" "graph/manager/hcom_util_unittest.cc" "graph/manager/graph_caching_allocator_unittest.cc" diff --git a/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc b/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc new file mode 100644 index 00000000..96283250 --- /dev/null +++ b/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc @@ -0,0 +1,607 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "common/ge_inner_error_codes.h" +#include "common/types.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "../utils/buffer_pool_graph_builder.h" +#include "graph/passes/buffer_pool_memory_pass.h" + +#define protected public +#define private public +#include "graph/build/memory/buffer_pool_mem_assigner.h" +#include "graph/build/memory/graph_mem_assigner.h" +#include "graph/build/stream_allocator.h" +#undef protected +#undef private + +namespace ge { +namespace { +const int64_t kMemoryTypeHBM = static_cast(RT_MEMORY_HBM); +const int64_t kMemoryTypeP2P = static_cast(RT_MEMORY_P2P_HBM); +const int64_t kMemoryTypeDDR = static_cast(RT_MEMORY_DDR); +const size_t kOffsetHBM = 10240; +const size_t kOffsetP2P = 20480; +const size_t kOffsetDDR = 30720; +const int64_t kMemAlignSize = 512; + +int64_t AlignMemSize(int64_t mem_size, int64_t align_size = kMemAlignSize) { + int64_t tmp = (mem_size + align_size - 1) / align_size * align_size; + return tmp; +} +int64_t AlignOutputMemSize(int64_t mem_size) { + int64_t tmp = (mem_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize; + // hccl need alignment + tmp = kMemAlignSize + tmp + kMemAlignSize; + return tmp; +} +} // namespace +class UtestBufferPoolMemAssignerTest : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} + +}; + +TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_normal_assign_success) { + ut::BufferPoolGraphBuilder builder("NormalGraph"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraph(); + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + std::map mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM}, + {kMemoryTypeP2P, kOffsetP2P}}; + int64_t offset_base = static_cast(kOffsetHBM + kMemAlignSize); + std::vector expect_offset = {(offset_base + 0), + (offset_base + AlignOutputMemSize(500)), + (offset_base + (AlignOutputMemSize(500) * 2)), + (offset_base + 0), + (offset_base + AlignOutputMemSize(1024))}; + + BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset); + ret = buffer_pool_mem_assigner.Assign(); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base + + AlignMemSize(5600, kMemAlignSize) + kMemAlignSize); + + { + auto prefetch = graph->FindNode("prefetch1"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(0)); + } + + { + auto prefetch = graph->FindNode("prefetch2"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(1)); + } + + { + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(2)); + } + + { + auto prefetch = graph->FindNode("prefetch4"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(3)); + } + + { + auto prefetch = graph->FindNode("prefetch5"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(4)); + } +} + +TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_normal_graph_with_multi_buffer_pool_assign_success) { + ut::BufferPoolGraphBuilder builder("NormalGraphWithMultiBufferPool"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraphWithMultiBufferPool(); + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + std::map mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM}, + {kMemoryTypeP2P, kOffsetP2P}}; + int64_t offset_base_0 = static_cast(kOffsetHBM + kMemAlignSize); + int64_t offset_base_1 = static_cast(kOffsetHBM + kMemAlignSize) + + AlignMemSize(5000, kMemAlignSize) + kMemAlignSize; + std::vector expect_offset = {(offset_base_0 + 0), + (offset_base_1 + 0), + (offset_base_0 + AlignOutputMemSize(500)), + (offset_base_0 + 0), + (offset_base_1 + AlignOutputMemSize(500))}; + + BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset); + ret = buffer_pool_mem_assigner.Assign(); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base_1 + + AlignMemSize(5000, kMemAlignSize) + kMemAlignSize); + + { + auto prefetch = graph->FindNode("prefetch1"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(0)); + } + + { + auto prefetch = graph->FindNode("prefetch2"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(1)); + } + + { + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(2)); + } + + { + auto prefetch = graph->FindNode("prefetch4"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(3)); + } + + { + auto prefetch = graph->FindNode("prefetch5"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(4)); + } +} + +TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_serial_graph_assign_success) { + ut::BufferPoolGraphBuilder builder("SerialGraph"); + ge::ComputeGraphPtr graph = builder.BuildSerialGraph(); + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + std::map mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM}, + {kMemoryTypeP2P, kOffsetP2P}}; + int64_t offset_base = static_cast(kOffsetHBM + kMemAlignSize); + std::vector expect_offset = {offset_base, offset_base, offset_base, offset_base, offset_base}; + + BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset); + ret = buffer_pool_mem_assigner.Assign(); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base + + AlignMemSize(2048, kMemAlignSize) + kMemAlignSize); + + { + auto prefetch = graph->FindNode("prefetch1"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(0)); + } + + { + auto prefetch = graph->FindNode("prefetch2"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(1)); + } + + { + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(2)); + } + + { + auto prefetch = graph->FindNode("prefetch4"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(3)); + } + + { + auto prefetch = graph->FindNode("prefetch5"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(4)); + } +} + +TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_subgraph_with_inner_dependency_assign_success) { + ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency"); + ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency(); + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + std::map mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM}, + {kMemoryTypeP2P, kOffsetP2P}}; + int64_t offset_base = static_cast(kOffsetHBM + kMemAlignSize); + std::vector expect_offset = {(offset_base + 0), + (offset_base + AlignOutputMemSize(500)), + (offset_base + (AlignOutputMemSize(500) * 2)), + (offset_base + 0), + (offset_base + AlignOutputMemSize(1024))}; + + BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset); + ret = buffer_pool_mem_assigner.Assign(); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base + + AlignMemSize(5600, kMemAlignSize) + kMemAlignSize); + + std::map all_nodes; + for (auto node : graph->GetAllNodes()) { + EXPECT_NE(node, nullptr); + all_nodes[node->GetName()] = node; + } + + { + auto prefetch = all_nodes.at("prefetch1"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(0)); + } + + { + auto prefetch = all_nodes.at("prefetch2"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(1)); + } + + { + auto prefetch = all_nodes.at("prefetch3"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(2)); + } + + { + auto prefetch = all_nodes.at("prefetch4"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(3)); + } + + { + auto prefetch = all_nodes.at("prefetch5"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(4)); + } +} + +TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_graph_with_multi_batch_assign_success) { + ut::BufferPoolGraphBuilder builder("GraphWithMultiBatch"); + ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiBatch(); + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + std::map mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM}, + {kMemoryTypeP2P, kOffsetP2P}}; + int64_t offset_base = static_cast(kOffsetHBM + kMemAlignSize); + std::vector expect_offset = {(offset_base + 0), + (offset_base + AlignOutputMemSize(500)), + (offset_base + (AlignOutputMemSize(500) * 2)), + (offset_base + 0), + (offset_base + AlignOutputMemSize(1024))}; + + BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset); + ret = buffer_pool_mem_assigner.Assign(); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base + + AlignMemSize(5600, kMemAlignSize) + kMemAlignSize); + + { + auto prefetch = graph->FindNode("batch_label_128/prefetch1"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(0)); + } + + { + auto prefetch = graph->FindNode("batch_label_128/prefetch2"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(1)); + } + + { + auto prefetch = graph->FindNode("batch_label_128/prefetch3"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(2)); + } + + { + auto prefetch = graph->FindNode("batch_label_128/prefetch4"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(3)); + } + + { + auto prefetch = graph->FindNode("batch_label_128/prefetch5"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(4)); + } + + { + auto prefetch = graph->FindNode("batch_label_256/prefetch1"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(0)); + } + + { + auto prefetch = graph->FindNode("batch_label_256/prefetch2"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(1)); + } + + { + auto prefetch = graph->FindNode("batch_label_256/prefetch3"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(2)); + } + + { + auto prefetch = graph->FindNode("batch_label_256/prefetch4"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(3)); + } + + { + auto prefetch = graph->FindNode("batch_label_256/prefetch5"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector output_offset = prefetch->GetOpDesc()->GetOutputOffset(); + EXPECT_EQ(output_offset.size(), 1); + EXPECT_EQ(output_offset.at(0), expect_offset.at(4)); + } +} + +TEST_F(UtestBufferPoolMemAssignerTest, test_AssignBufferPoolMemory_success) { + ut::BufferPoolGraphBuilder builder("NormalGraph"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraph(); + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + std::map memory_offset = {{kMemoryTypeHBM, MemoryOffset(RT_MEMORY_HBM, kOffsetHBM)}, + {kMemoryTypeP2P, MemoryOffset(RT_MEMORY_P2P_HBM, kOffsetP2P)}}; + + GraphMemoryAssigner graph_memory_assigner(graph); + graph_memory_assigner.memory_offset_ = memory_offset; + ret = graph_memory_assigner.AssignBufferPoolMemory(); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestBufferPoolMemAssignerTest, test_AssignBufferPoolMemory_fail) { + ut::BufferPoolGraphBuilder builder("NormalGraph"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraph(); + std::map memory_offset = {{kMemoryTypeHBM, MemoryOffset(RT_MEMORY_HBM, kOffsetHBM)}, + {kMemoryTypeP2P, MemoryOffset(RT_MEMORY_P2P_HBM, kOffsetP2P)}}; + { + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + std::vector type_list = {static_cast(RT_MEMORY_P2P_HBM)}; + bool set_attr = ge::AttrUtils::SetListInt(prefetch->GetOpDesc(), ATTR_NAME_OUTPUT_MEM_TYPE_LIST, type_list); + EXPECT_EQ(set_attr, true); + + GraphMemoryAssigner graph_memory_assigner(graph); + graph_memory_assigner.memory_offset_ = memory_offset; + Status ret = graph_memory_assigner.AssignBufferPoolMemory(); + EXPECT_EQ(ret, FAILED); + } + + { + std::vector node_list = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"}; + std::vector type_list = {static_cast(RT_MEMORY_L1)}; + for (auto &node_name : node_list) { + auto prefetch = graph->FindNode(node_name); + EXPECT_NE(prefetch, nullptr); + EXPECT_NE(prefetch->GetOpDesc(), nullptr); + bool set_attr = ge::AttrUtils::SetListInt(prefetch->GetOpDesc(), ATTR_NAME_OUTPUT_MEM_TYPE_LIST, type_list); + EXPECT_EQ(set_attr, true); + } + GraphMemoryAssigner graph_memory_assigner(graph); + graph_memory_assigner.memory_offset_ = memory_offset; + Status ret = graph_memory_assigner.AssignBufferPoolMemory(); + EXPECT_EQ(ret, FAILED); + } +} + +TEST_F(UtestBufferPoolMemAssignerTest, test_RefreshEventsWithReuse_success) { + ut::BufferPoolGraphBuilder builder("NormalGraph"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraph(); + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + std::map all_nodes; + for (auto node : graph->GetAllNodes()) { + EXPECT_NE(node, nullptr); + all_nodes[node->GetName()] = node; + } + + Graph2SubGraphInfoList sub_graphs; + StreamAllocator stream_allocator(graph, sub_graphs); + stream_allocator.event_num_ = 65520; + + // stream ctrl event + stream_allocator.AddSendEventId(all_nodes.at("prefetch1"), 30); + stream_allocator.AddRecvEventId(all_nodes.at("add1"), 30); + + stream_allocator.AddSendEventId(all_nodes.at("prefetch2"), 31); + stream_allocator.AddRecvEventId(all_nodes.at("add2"), 31); + + stream_allocator.AddSendEventId(all_nodes.at("prefetch3"), 32); + stream_allocator.AddRecvEventId(all_nodes.at("add3"), 32); + + stream_allocator.AddSendEventId(all_nodes.at("prefetch4"), 33); + stream_allocator.AddRecvEventId(all_nodes.at("add4"), 33); + + stream_allocator.AddSendEventId(all_nodes.at("add2"), 34); + stream_allocator.AddRecvEventId(all_nodes.at("prefetch4"), 34); + + stream_allocator.AddSendEventId(all_nodes.at("prefetch5"), 35); + stream_allocator.AddRecvEventId(all_nodes.at("add5"), 35); + + stream_allocator.AddSendEventId(all_nodes.at("add3"), 36); + stream_allocator.AddRecvEventId(all_nodes.at("prefetch5"), 36); + + // other event + stream_allocator.AddSendEventId(all_nodes.at("prefetch1"), 37); + stream_allocator.AddRecvEventId(all_nodes.at("add5"), 37); + + + ret = stream_allocator.RefreshEventsWithReuse(); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ((stream_allocator.node_to_send_events_.at(all_nodes.at("prefetch1"))).size(), 2); + EXPECT_EQ((stream_allocator.node_to_send_events_.at(all_nodes.at("prefetch5"))).size(), 1); + EXPECT_EQ((stream_allocator.node_to_recv_events_.at(all_nodes.at("prefetch5"))).size(), 1); + EXPECT_EQ((stream_allocator.node_to_recv_events_.at(all_nodes.at("add5"))).size(), 2); + EXPECT_EQ(stream_allocator.event_num_, 5); +} + +TEST_F(UtestBufferPoolMemAssignerTest, test_RefreshEventsWithReuse_fail) { + ut::BufferPoolGraphBuilder builder("NormalGraph"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraph(); + + std::map all_nodes; + for (auto node : graph->GetAllNodes()) { + EXPECT_NE(node, nullptr); + all_nodes[node->GetName()] = node; + } + std::vector> event_info = {{"SendTo;add1;0"}, + {"SendTo;add2;1"}, + {"SendTo;add3;2"}, + {"SendTo;add4;3", "RecvFrom;add2;0"}, + {"SendTo;add5;0", "RecvFrom;add3;1"}}; + + (void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]); + (void) AttrUtils::SetListStr(all_nodes.at("prefetch2")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[1]); + (void) AttrUtils::SetListStr(all_nodes.at("prefetch3")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[2]); + (void) AttrUtils::SetListStr(all_nodes.at("prefetch4")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[3]); + (void) AttrUtils::SetListStr(all_nodes.at("prefetch5")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[4]); + + Graph2SubGraphInfoList sub_graphs; + StreamAllocator stream_allocator(graph, sub_graphs); + stream_allocator.event_num_ = 65520; + + // Item num of raw event info is invalid + event_info[0][0] = "SendTo;add1;0;1"; + (void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]); + Status ret = stream_allocator.RefreshEventsWithReuse(); + EXPECT_EQ(ret, PARAM_INVALID); + + // Event id is invalid argument + event_info[0][0] = "SendTo;add1;event_id"; + (void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]); + ret = stream_allocator.RefreshEventsWithReuse(); + EXPECT_EQ(ret, PARAM_INVALID); + + // Event id is out of range + event_info[0][0] = "SendTo;add1;666666666666666666666666666666666666666"; + (void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]); + ret = stream_allocator.RefreshEventsWithReuse(); + EXPECT_EQ(ret, PARAM_INVALID); + + // Event id is negative + event_info[0][0] = "SendTo;add1;-2"; + (void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]); + ret = stream_allocator.RefreshEventsWithReuse(); + EXPECT_EQ(ret, PARAM_INVALID); + + // Key word is not supported + event_info[0][0] = "SendToKey;add1;2"; + (void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]); + ret = stream_allocator.RefreshEventsWithReuse(); + EXPECT_EQ(ret, PARAM_INVALID); +} +} // namespace ge + diff --git a/tests/ut/ge/graph/passes/buffer_pool_memory_pass_unittest.cc b/tests/ut/ge/graph/passes/buffer_pool_memory_pass_unittest.cc new file mode 100644 index 00000000..a59ca54f --- /dev/null +++ b/tests/ut/ge/graph/passes/buffer_pool_memory_pass_unittest.cc @@ -0,0 +1,591 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "common/ge_inner_error_codes.h" +#include "common/types.h" +#include "graph/manager/graph_var_manager.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "inc/pass_manager.h" +#include "graph_builder_utils.h" +#include "../utils/buffer_pool_graph_builder.h" +#include "graph/passes/buffer_pool_memory_pass.h" + +namespace ge { +class UtestBufferPoolMemoryPass : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} +}; + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_normal_success_test) { + ut::BufferPoolGraphBuilder builder("NormalGraph"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraph(); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch1"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add1;0"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch2"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add2;1"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add3;2"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch4"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add4;3"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add2;0"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add2"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch5"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add5;0"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add3;1"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add3"); + } +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_normal_graph_with_multi_buffer_pool_success_test) { + ut::BufferPoolGraphBuilder builder("NormalGraphWithMultiBufferPool"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraphWithMultiBufferPool(); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch1"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add1;0"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch2"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add2;3"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add3;1"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch4"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add4;2"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add3;0"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add3"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch5"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add5;4"); + } +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_contain_one_node_success_test) { + ut::BufferPoolGraphBuilder builder("SerialGraph"); + ge::ComputeGraphPtr graph = builder.BuildSerialGraph(); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch1"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add1;0"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch2"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add2;1"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add1;2"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add1"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add3;2"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add2;0"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add2"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch4"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add4;0"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add3;1"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add3"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch5"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add5;1"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add4;2"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add4"); + } +} + +TEST_F(UtestBufferPoolMemoryPass, calc_node_with_multi_buffer_pool_input_success_test) { + ut::BufferPoolGraphBuilder builder("GraphWithMultiPrefetch"); + ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiPrefetch(); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch1"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 0); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch2"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add1;0"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 0); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch4"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add2;1"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add1;2"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add1"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch5"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add3;2"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add2;0"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add2"); + } +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_in_different_subgraph_success_test) { + ut::BufferPoolGraphBuilder builder("GraphWithSubgraph"); + ge::ComputeGraphPtr graph = builder.BuildGraphWithSubgraph(); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + std::map all_nodes; + for (auto node : graph->GetAllNodes()) { + EXPECT_NE(node, nullptr); + all_nodes[node->GetName()] = node; + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch1"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add1;0"); + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch2"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add2;1"); + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add3;2"); + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch4"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add4;3"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 0); + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch5"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add5;4"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 1); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "prefetch4"); + } +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_in_different_subgraph_with_inner_dependency_success_test) { + ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency"); + ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency(); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + std::map all_nodes; + for (auto node : graph->GetAllNodes()) { + EXPECT_NE(node, nullptr); + all_nodes[node->GetName()] = node; + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch1"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add1;0"); + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch2"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add2;1"); + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add3;2"); + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch4"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;add4;3"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 1); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "prefetch3"); + } + + { + std::vector event_info; + auto prefetch = all_nodes.at("prefetch5"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add5;4"); + EXPECT_EQ(event_info.at(1), "RecvFrom;add3;0"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add3"); + } +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_with_batch_label_success_test) { + ut::BufferPoolGraphBuilder builder("GraphWithMultiBatch"); + ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiBatch(); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + { + std::vector event_info; + auto prefetch = graph->FindNode("batch_label_256/prefetch1"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add1;4"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("batch_label_256/prefetch2"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add2;5"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("batch_label_256/prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add3;6"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("batch_label_256/prefetch4"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add4;7"); + EXPECT_EQ(event_info.at(1), "RecvFrom;batch_label_256/add2;4"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "batch_label_256/add2"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("batch_label_256/prefetch5"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add5;4"); + EXPECT_EQ(event_info.at(1), "RecvFrom;batch_label_256/add3;5"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "batch_label_256/add3"); + } +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_node_has_multi_output_success_test) { + ut::BufferPoolGraphBuilder builder("GraphWithMultiOutputPrefetch"); + ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiOutputPrefetch(); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch1"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;prefetch1_memcpy_async;0"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch2"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;prefetch2_memcpy_async;1"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 1); + EXPECT_EQ(event_info.at(0), "SendTo;prefetch3_memcpy_async;2"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch4"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;prefetch4_memcpy_async;3"); + EXPECT_EQ(event_info.at(1), "RecvFrom;prefetch2_memcpy_async;0"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "prefetch2_memcpy_async"); + } + + { + std::vector event_info; + auto prefetch = graph->FindNode("prefetch5"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info); + EXPECT_EQ(event_info.size(), 2); + EXPECT_EQ(event_info.at(0), "SendTo;add5;0"); + EXPECT_EQ(event_info.at(1), "RecvFrom;prefetch3_memcpy_async;1"); + auto in_ctrl_nodes = prefetch->GetInControlNodes(); + EXPECT_EQ(in_ctrl_nodes.size(), 2); + EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "prefetch3_memcpy_async"); + } +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_has_different_size_fail_test) { + ut::BufferPoolGraphBuilder builder("NormalGraph"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraph(); + const int64_t dummy_size = 256; + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + (void) AttrUtils::SetInt(prefetch->GetOpDesc(), "_buffer_pool_size", dummy_size); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, FAILED); +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_size_is_not_enough_fail_test) { + ut::BufferPoolGraphBuilder builder("NormalGraph"); + ge::ComputeGraphPtr graph = builder.BuildNormalGraph(); + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + builder.SetPrefetchNodeInfo(prefetch, buffer_pool_id, buffer_pool_size, {buffer_pool_size + 512}); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, FAILED); +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_size_is_not_enough_for_multi_fail_test) { + ut::BufferPoolGraphBuilder builder("GraphWithMultiPrefetch"); + ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiPrefetch(); + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + auto prefetch = graph->FindNode("prefetch3"); + EXPECT_NE(prefetch, nullptr); + builder.SetPrefetchNodeInfo(prefetch, buffer_pool_id, buffer_pool_size, {buffer_pool_size}); + + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, FAILED); +} + +TEST_F(UtestBufferPoolMemoryPass, buffer_pool_node_has_multi_input_output_fail_test) { + ut::BufferPoolGraphBuilder builder("GraphWithMultiInputOutputPrefetch"); + ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiInputOutputPrefetch(); + BufferPoolMemoryPass buffer_pool_mem_pass; + Status ret = buffer_pool_mem_pass.Run(graph); + EXPECT_EQ(ret, FAILED); +} +} // namespace ge diff --git a/tests/ut/ge/graph/utils/buffer_pool_graph_builder.cc b/tests/ut/ge/graph/utils/buffer_pool_graph_builder.cc new file mode 100644 index 00000000..dd52f287 --- /dev/null +++ b/tests/ut/ge/graph/utils/buffer_pool_graph_builder.cc @@ -0,0 +1,978 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "buffer_pool_graph_builder.h" +#include "common/ge_inner_error_codes.h" +#include "common/types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/graph_utils.h" + +namespace ge { +namespace ut { +BufferPoolGraphBuilder::BufferPoolGraphBuilder(const std::string &name) { + graph_name_ = name; +} + +BufferPoolGraphBuilder::InnerGraphBuilder::InnerGraphBuilder(const std::string &name) { + graph_ = std::make_shared(name); + EXPECT_NE(graph_, nullptr); +} + +NodePtr BufferPoolGraphBuilder::InnerGraphBuilder::AddNode(const std::string &name, const std::string &type, + int in_cnt, int out_cnt, + Format format, DataType data_type, + std::vector shape) { + auto tensor_desc = std::make_shared(); + EXPECT_NE(tensor_desc, nullptr); + tensor_desc->SetShape(GeShape(std::move(shape))); + tensor_desc->SetFormat(format); + tensor_desc->SetDataType(data_type); + auto op_desc = std::make_shared(name, type); + EXPECT_NE(op_desc, nullptr); + for (int i = 0; i < in_cnt; ++i) { + op_desc->AddInputDesc(tensor_desc->Clone()); + } + for (int i = 0; i < out_cnt; ++i) { + op_desc->AddOutputDesc(tensor_desc->Clone()); + } + return graph_->AddNode(op_desc); +} + +void BufferPoolGraphBuilder::InnerGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, + NodePtr &dst_node, int dst_idx) { + EXPECT_NE(src_node, nullptr); + EXPECT_NE(dst_node, nullptr); + GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); +} + +void BufferPoolGraphBuilder::InnerGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { + EXPECT_NE(src_node, nullptr); + EXPECT_NE(dst_node, nullptr); + GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); +} + +void BufferPoolGraphBuilder::SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size, + const std::string &batch_label) { + EXPECT_NE(node, nullptr); + (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, pool_id); + (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, pool_size); + if (!batch_label.empty()) { + (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); + } +} + +void BufferPoolGraphBuilder::SetBatchLabel(NodePtr &node, const std::string &batch_label) { + EXPECT_NE(node, nullptr); + (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); + +} + +void BufferPoolGraphBuilder::SetOutputMemSize(NodePtr &node, const std::vector &mem_size) { + EXPECT_NE(node, nullptr); + EXPECT_NE(node->GetOpDesc(), nullptr); + size_t output_size = node->GetOpDesc()->GetOutputsSize(); + EXPECT_EQ(output_size, mem_size.size()); + for (size_t i = 0; i < output_size; ++i) { + auto output_op_desc = node->GetOpDesc()->MutableOutputDesc(i); + ge::TensorUtils::SetSize(*output_op_desc, mem_size[i]); + } +} + +void BufferPoolGraphBuilder::SetWorkSpaceMemSize(NodePtr &node, const std::vector &ws_bytes) { + EXPECT_NE(node, nullptr); + EXPECT_NE(node->GetOpDesc(), nullptr); + node->GetOpDesc()->SetWorkspaceBytes(ws_bytes); +} + +void BufferPoolGraphBuilder::SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size, + const std::vector &mem_size, + const std::vector &ws_bytes, + const std::string &batch_label) { + SetBufferPool(node, pool_id, pool_size, batch_label); + SetOutputMemSize(node, mem_size); + SetWorkSpaceMemSize(node, ws_bytes); +} + +/// +/// Normal graph +/// +/// w1 w2 w3 w4 w5 +/// \ \ \ \ \ +/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 +/// \ \ \ \ \ +/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output +/// +/// +/// Memory distribution: +/// +/// |___w1__|__w2__|__w3__|__| +/// +/// |_____w4_____|_____w5____| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraph() { + auto builder = InnerGraphBuilder(graph_name_); + auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); + auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); + auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); + auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); + auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); + + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + + auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); + auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); + + auto add1 = builder.AddNode("add1", ADD, 2, 1); + auto add2 = builder.AddNode("add2", ADD, 2, 1); + auto add3 = builder.AddNode("add3", ADD, 2, 1); + auto add4 = builder.AddNode("add4", ADD, 2, 1); + auto add5 = builder.AddNode("add5", ADD, 2, 1); + auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); + auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); + + builder.AddDataEdge(w1, 0, prefetch1, 0); + builder.AddDataEdge(w2, 0, prefetch2, 0); + builder.AddDataEdge(w3, 0, prefetch3, 0); + builder.AddDataEdge(w4, 0, prefetch4, 0); + builder.AddDataEdge(w5, 0, prefetch5, 0); + + builder.AddDataEdge(const1, 0, add1, 0); + builder.AddDataEdge(prefetch1, 0, add1, 1); + + builder.AddDataEdge(add1, 0, add2, 0); + builder.AddDataEdge(prefetch2, 0, add2, 1); + + builder.AddDataEdge(add2, 0, add3, 0); + builder.AddDataEdge(prefetch3, 0, add3, 1); + + builder.AddDataEdge(add3, 0, add4, 0); + builder.AddDataEdge(prefetch4, 0, add4, 1); + + builder.AddDataEdge(add4, 0, add5, 0); + builder.AddDataEdge(prefetch5, 0, add5, 1); + + builder.AddDataEdge(add5, 0, net_output, 0); + + auto compute_graph = builder.GetGraph(); + + return compute_graph; +} + +/// +/// Normal graph with multi buffer pool +/// +/// w1 w2 w3 w4 w5 +/// \ \ \ \ \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 +/// (pool0) (pool1) (pool0) (pool0) (pool1) +/// \ \ \ \ \ + /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output +/// +/// +/// Memory distribution: +/// +/// |___w1__|__w3__|_________| +/// |_____w4_____|___________| +/// +/// |___w2__|_____w5___|_____| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraphWithMultiBufferPool() { + auto builder = InnerGraphBuilder(graph_name_); + auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); + auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); + auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); + auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); + auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); + + const int64_t buffer_pool_id_0 = 0; + const int64_t buffer_pool_id_1 = 1; + const int64_t buffer_pool_size = 5000; + + auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id_0, buffer_pool_size, {500}); + auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id_1, buffer_pool_size, {500}); + auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id_0, buffer_pool_size, {500}); + auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id_0, buffer_pool_size, {1024}); + auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id_1, buffer_pool_size, {1024}); + + auto add1 = builder.AddNode("add1", ADD, 2, 1); + auto add2 = builder.AddNode("add2", ADD, 2, 1); + auto add3 = builder.AddNode("add3", ADD, 2, 1); + auto add4 = builder.AddNode("add4", ADD, 2, 1); + auto add5 = builder.AddNode("add5", ADD, 2, 1); + auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); + auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); + + builder.AddDataEdge(w1, 0, prefetch1, 0); + builder.AddDataEdge(w2, 0, prefetch2, 0); + builder.AddDataEdge(w3, 0, prefetch3, 0); + builder.AddDataEdge(w4, 0, prefetch4, 0); + builder.AddDataEdge(w5, 0, prefetch5, 0); + + builder.AddDataEdge(const1, 0, add1, 0); + builder.AddDataEdge(prefetch1, 0, add1, 1); + + builder.AddDataEdge(add1, 0, add2, 0); + builder.AddDataEdge(prefetch2, 0, add2, 1); + + builder.AddDataEdge(add2, 0, add3, 0); + builder.AddDataEdge(prefetch3, 0, add3, 1); + + builder.AddDataEdge(add3, 0, add4, 0); + builder.AddDataEdge(prefetch4, 0, add4, 1); + + builder.AddDataEdge(add4, 0, add5, 0); + builder.AddDataEdge(prefetch5, 0, add5, 1); + + builder.AddDataEdge(add5, 0, net_output, 0); + + auto compute_graph = builder.GetGraph(); + + return compute_graph; +} + +/// +/// SerialGraph: Buffer pool size only can contain one prefetch node +/// +/// w1 w2 w3 w4 w5 +/// \ \ \ \ \ +/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 +/// \ \ \ \ \ +/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output +/// +/// +/// Memory distribution: +/// +/// |____w1_____|__| +/// +/// |____w2_____|__| +/// +/// |____w3_____|__| +/// +/// |______w4______| +/// +/// |______w5______| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildSerialGraph() { + auto builder = InnerGraphBuilder(graph_name_); + auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); + auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); + auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); + auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); + auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); + + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 2048; + + auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); + auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); + + auto add1 = builder.AddNode("add1", ADD, 2, 1); + auto add2 = builder.AddNode("add2", ADD, 2, 1); + auto add3 = builder.AddNode("add3", ADD, 2, 1); + auto add4 = builder.AddNode("add4", ADD, 2, 1); + auto add5 = builder.AddNode("add5", ADD, 2, 1); + auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); + auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); + + builder.AddDataEdge(w1, 0, prefetch1, 0); + builder.AddDataEdge(w2, 0, prefetch2, 0); + builder.AddDataEdge(w3, 0, prefetch3, 0); + builder.AddDataEdge(w4, 0, prefetch4, 0); + builder.AddDataEdge(w5, 0, prefetch5, 0); + + builder.AddDataEdge(const1, 0, add1, 0); + builder.AddDataEdge(prefetch1, 0, add1, 1); + + builder.AddDataEdge(add1, 0, add2, 0); + builder.AddDataEdge(prefetch2, 0, add2, 1); + + builder.AddDataEdge(add2, 0, add3, 0); + builder.AddDataEdge(prefetch3, 0, add3, 1); + + builder.AddDataEdge(add3, 0, add4, 0); + builder.AddDataEdge(prefetch4, 0, add4, 1); + + builder.AddDataEdge(add4, 0, add5, 0); + builder.AddDataEdge(prefetch5, 0, add5, 1); + + builder.AddDataEdge(add5, 0, net_output, 0); + + auto compute_graph = builder.GetGraph(); + + return compute_graph; +} + +/// +/// GraphWithMultiPrefetch: Calc node with more prefetch node +/// +/// w1 w2 w3 w4 w5 +/// \ \ \ \ \ +/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1 +/// \ / \ / \ / +/// \ / \ / \ / +/// \ / \ / \ / +/// add1 ------ c ------- add2 ----- c ----- add3 +/// | | | +/// | | | +/// --------------- net_output ------------ +/// +/// Memory distribution: +/// +/// |___w1__|__w2__|__w3__|__| +/// +/// |_____w4_____|_____w5____| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiPrefetch() { + auto builder = InnerGraphBuilder(graph_name_); + auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); + auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); + auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); + auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); + auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); + + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + + auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); + auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); + + auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); + auto add1 = builder.AddNode("add1", ADD, 2, 1); + auto add2 = builder.AddNode("add2", ADD, 2, 1); + auto add3 = builder.AddNode("add3", ADD, 2, 1); + auto net_output = builder.AddNode("net_output", NETOUTPUT, 3, 0); + + builder.AddDataEdge(w1, 0, prefetch1, 0); + builder.AddDataEdge(w2, 0, prefetch2, 0); + builder.AddDataEdge(w3, 0, prefetch3, 0); + builder.AddDataEdge(w4, 0, prefetch4, 0); + builder.AddDataEdge(w5, 0, prefetch5, 0); + + builder.AddDataEdge(prefetch1, 0, add1, 0); + builder.AddDataEdge(prefetch2, 0, add1, 1); + + builder.AddDataEdge(prefetch3, 0, add2, 0); + builder.AddDataEdge(prefetch4, 0, add2, 1); + + builder.AddDataEdge(const1, 0, add3, 0); + builder.AddDataEdge(prefetch5, 0, add3, 1); + + builder.AddDataEdge(add1, 0, net_output, 0); + builder.AddDataEdge(add2, 0, net_output, 1); + builder.AddDataEdge(add3, 0, net_output, 2); + + builder.AddControlEdge(add1, add2); + builder.AddControlEdge(add2, add3); + + auto compute_graph = builder.GetGraph(); + + return compute_graph; +} + +/// +/// GraphWithSubgraph: Calc node in different subgraph +/// +/// +/// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output +/// +/// +/// Subgraph1: Subgraph2: +/// +/// w1 w2 w3 w4 w5 +/// \ \ \ \ \ +/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 +/// \ \ \ \ \ +/// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out +/// +/// +/// Memory distribution: +/// +/// |___w1__|__w2__|__w3__|__| +/// +/// |_____w4_____|_____w5____| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithSubgraph() { + auto builder = InnerGraphBuilder(graph_name_); + + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + + // Subgraph1 + auto subgraph_builder1 = InnerGraphBuilder("Subgraph1"); + auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1); + auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1); + auto w3 = subgraph_builder1.AddNode("w3", VARIABLE, 0, 1); + + auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch3 = subgraph_builder1.AddNode("prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); + auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0); + auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1); + + auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1); + auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1); + auto add3 = subgraph_builder1.AddNode("add3", ADD, 2, 1); + + subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0); + subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0); + subgraph_builder1.AddDataEdge(w3, 0, prefetch3, 0); + subgraph_builder1.AddDataEdge(const1, 0, add1, 0); + subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1); + subgraph_builder1.AddDataEdge(add1, 0, add2, 0); + subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1); + subgraph_builder1.AddDataEdge(add2, 0, add3, 0); + subgraph_builder1.AddDataEdge(prefetch3, 0, add3, 1); + subgraph_builder1.AddDataEdge(add3, 0, subgraph1_out, 0); + auto subgraph1 = subgraph_builder1.GetGraph(); + for (auto &node : subgraph1->GetDirectNode()) { + node->SetOwnerComputeGraph(subgraph1); + } + + // Subgraph2 + auto subgraph_builder2 = InnerGraphBuilder("Subgraph2"); + auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1); + auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1); + + auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); + auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); + + auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1); + auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1); + auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1); + auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1); + + subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0); + subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0); + subgraph_builder2.AddDataEdge(data1, 0, add4, 0); + subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1); + subgraph_builder2.AddDataEdge(add4, 0, add5, 0); + subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1); + subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0); + + auto subgraph2 = subgraph_builder2.GetGraph(); + for (auto &node : subgraph2->GetDirectNode()) { + node->SetOwnerComputeGraph(subgraph2); + } + + // root graph + auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1); + auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0); + auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); + builder.AddDataEdge(call_node1, 0, call_node2, 0); + builder.AddDataEdge(call_node2, 0, net_output, 0); + auto compute_graph = builder.GetGraph(); + call_node1->SetOwnerComputeGraph(compute_graph); + call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName()); + call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName()); + call_node2->SetOwnerComputeGraph(compute_graph); + call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName()); + call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName()); + + subgraph1->SetParentNode(call_node1); + subgraph1->SetParentGraph(compute_graph); + subgraph2->SetParentNode(call_node2); + subgraph2->SetParentGraph(compute_graph); + compute_graph->AddSubGraph(subgraph1); + compute_graph->AddSubGraph(subgraph2); + + return compute_graph; +} + +/// +/// SubgraphWithInnerDependency: Calc node in different subgraph with inner dependency +/// +/// +/// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output +/// +/// +/// Subgraph1: Subgraph2: +/// +/// w1 w2 w3 w4 w5 +/// \ \ \ \ \ +/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 +/// \ \ \ \ \ +/// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out +/// +/// +/// Memory distribution: +/// +/// |___w1__|__w2__|__w3__|__| +/// +/// |_____w4_____|_____w5____| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildSubgraphWithInnerDependency() { + auto builder = InnerGraphBuilder(graph_name_); + + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + + // Subgraph1 + auto subgraph_builder1 = InnerGraphBuilder("Subgraph1"); + auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1); + auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1); + + auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); + auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0); + auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1); + + auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1); + auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1); + + subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0); + subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0); + subgraph_builder1.AddDataEdge(const1, 0, add1, 0); + subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1); + subgraph_builder1.AddDataEdge(add1, 0, add2, 0); + subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1); + subgraph_builder1.AddDataEdge(add2, 0, subgraph1_out, 0); + auto subgraph1 = subgraph_builder1.GetGraph(); + for (auto &node : subgraph1->GetDirectNode()) { + node->SetOwnerComputeGraph(subgraph1); + } + + // Subgraph2 + auto subgraph_builder2 = InnerGraphBuilder("Subgraph2"); + auto w3 = subgraph_builder2.AddNode("w3", VARIABLE, 0, 1); + auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1); + auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1); + + auto prefetch3 = subgraph_builder2.AddNode("prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); + auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); + + auto add3 = subgraph_builder2.AddNode("add3", ADD, 2, 1); + auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1); + auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1); + auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1); + auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1); + + subgraph_builder2.AddDataEdge(w3, 0, prefetch3, 0); + subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0); + subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0); + subgraph_builder2.AddDataEdge(data1, 0, add3, 0); + subgraph_builder2.AddDataEdge(prefetch3, 0, add3, 1); + subgraph_builder2.AddDataEdge(add3, 0, add4, 0); + subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1); + subgraph_builder2.AddDataEdge(add4, 0, add5, 0); + subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1); + subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0); + + auto subgraph2 = subgraph_builder2.GetGraph(); + for (auto &node : subgraph2->GetDirectNode()) { + node->SetOwnerComputeGraph(subgraph2); + } + + // root graph + auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1); + auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0); + auto net_output = subgraph_builder2.AddNode("net_output", NETOUTPUT, 1, 0); + builder.AddDataEdge(call_node1, 0, call_node2, 0); + builder.AddDataEdge(call_node2, 0, net_output, 0); + auto compute_graph = builder.GetGraph(); + call_node1->SetOwnerComputeGraph(compute_graph); + call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName()); + call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName()); + call_node2->SetOwnerComputeGraph(compute_graph); + call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName()); + call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName()); + + subgraph1->SetParentNode(call_node1); + subgraph1->SetParentGraph(compute_graph); + subgraph2->SetParentNode(call_node2); + subgraph2->SetParentGraph(compute_graph); + compute_graph->AddSubGraph(subgraph1); + compute_graph->AddSubGraph(subgraph2); + + return compute_graph; +} + +/// +/// BuildGraphWithMultiBatch: Different batch label +/// +/// +/// batch_label_128 +/// +/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- +/// / / / / / / \ +/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \ +/// const1 switch_false / / / / / \ +/// \ / / / / / / \ +/// switch1 w1 w2 w3 w4 w5 merge1 -- net_output +/// / \ \ \ \ \ \ / +/// const2 switch_true \ \ \ \ \ / +/// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 / +/// \ \ \ \ \ \ / +/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- +/// +/// batch_label_256 +/// +/// +/// Memory distribution: +/// +/// |___w1__|__w2__|__w3__|__| +/// +/// |_____w4_____|_____w5____| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiBatch() { + auto builder = InnerGraphBuilder(graph_name_); + auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); + auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); + auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); + auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); + auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); + + auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); + auto const2 = builder.AddNode("const2", CONSTANTOP, 0, 1); + auto switch1 = builder.AddNode("switch1", SWITCH, 2, 2); + auto switch_false = builder.AddNode("switch_false", IDENTITY, 1, 1); + auto switch_true = builder.AddNode("switch_true", IDENTITY, 1, 1); + auto merge1 = builder.AddNode("merge1", MERGE, 2, 2); + auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); + + builder.AddDataEdge(const1, 0, switch1, 0); + builder.AddDataEdge(const2, 0, switch1, 1); + builder.AddDataEdge(switch1, 0, switch_false, 0); + builder.AddDataEdge(switch1, 1, switch_true, 0); + builder.AddDataEdge(merge1, 0, net_output, 0); + + std::string batch_label_128 = "batch_128"; + std::string batch_label_256 = "batch_256"; + + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + + { + auto prefetch1 = builder.AddNode("batch_label_128/prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); + auto prefetch2 = builder.AddNode("batch_label_128/prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); + auto prefetch3 = builder.AddNode("batch_label_128/prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); + auto prefetch4 = builder.AddNode("batch_label_128/prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128); + auto prefetch5 = builder.AddNode("batch_label_128/prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128); + + auto add1 = builder.AddNode("batch_label_128/add1", ADD, 2, 1); + SetBatchLabel(add1, batch_label_128); + auto add2 = builder.AddNode("batch_label_128/add2", ADD, 2, 1); + SetBatchLabel(add2, batch_label_128); + auto add3 = builder.AddNode("batch_label_128/add3", ADD, 2, 1); + SetBatchLabel(add3, batch_label_128); + auto add4 = builder.AddNode("batch_label_128/add4", ADD, 2, 1); + SetBatchLabel(add4, batch_label_128); + auto add5 = builder.AddNode("batch_label_128/add5", ADD, 2, 1); + SetBatchLabel(add5, batch_label_128); + auto const1 = builder.AddNode("batch_label_128/const1", CONSTANTOP, 0, 1); + SetBatchLabel(const1, batch_label_128); + + builder.AddDataEdge(w1, 0, prefetch1, 0); + builder.AddDataEdge(w2, 0, prefetch2, 0); + builder.AddDataEdge(w3, 0, prefetch3, 0); + builder.AddDataEdge(w4, 0, prefetch4, 0); + builder.AddDataEdge(w5, 0, prefetch5, 0); + + builder.AddDataEdge(const1, 0, add1, 0); + builder.AddDataEdge(prefetch1, 0, add1, 1); + + builder.AddDataEdge(add1, 0, add2, 0); + builder.AddDataEdge(prefetch2, 0, add2, 1); + + builder.AddDataEdge(add2, 0, add3, 0); + builder.AddDataEdge(prefetch3, 0, add3, 1); + + builder.AddDataEdge(add3, 0, add4, 0); + builder.AddDataEdge(prefetch4, 0, add4, 1); + + builder.AddDataEdge(add4, 0, add5, 0); + builder.AddDataEdge(prefetch5, 0, add5, 1); + + builder.AddDataEdge(add5, 0, merge1, 0); + builder.AddControlEdge(switch_false, const1); + } + + { + auto prefetch1 = builder.AddNode("batch_label_256/prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); + auto prefetch2 = builder.AddNode("batch_label_256/prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); + auto prefetch3 = builder.AddNode("batch_label_256/prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); + auto prefetch4 = builder.AddNode("batch_label_256/prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256); + auto prefetch5 = builder.AddNode("batch_label_256/prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256); + + auto add1 = builder.AddNode("batch_label_256/add1", ADD, 2, 1); + SetBatchLabel(add1, batch_label_256); + auto add2 = builder.AddNode("batch_label_256/add2", ADD, 2, 1); + SetBatchLabel(add2, batch_label_256); + auto add3 = builder.AddNode("batch_label_256/add3", ADD, 2, 1); + SetBatchLabel(add3, batch_label_256); + auto add4 = builder.AddNode("batch_label_256/add4", ADD, 2, 1); + SetBatchLabel(add4, batch_label_256); + auto add5 = builder.AddNode("batch_label_256/add5", ADD, 2, 1); + SetBatchLabel(add5, batch_label_256); + auto const1 = builder.AddNode("batch_label_256/const1", CONSTANTOP, 0, 1); + SetBatchLabel(const1, batch_label_128); + + builder.AddDataEdge(w1, 0, prefetch1, 0); + builder.AddDataEdge(w2, 0, prefetch2, 0); + builder.AddDataEdge(w3, 0, prefetch3, 0); + builder.AddDataEdge(w4, 0, prefetch4, 0); + builder.AddDataEdge(w5, 0, prefetch5, 0); + + builder.AddDataEdge(const1, 0, add1, 0); + builder.AddDataEdge(prefetch1, 0, add1, 1); + + builder.AddDataEdge(add1, 0, add2, 0); + builder.AddDataEdge(prefetch2, 0, add2, 1); + + builder.AddDataEdge(add2, 0, add3, 0); + builder.AddDataEdge(prefetch3, 0, add3, 1); + + builder.AddDataEdge(add3, 0, add4, 0); + builder.AddDataEdge(prefetch4, 0, add4, 1); + + builder.AddDataEdge(add4, 0, add5, 0); + builder.AddDataEdge(prefetch5, 0, add5, 1); + + builder.AddDataEdge(add5, 0, merge1, 1); + + builder.AddControlEdge(switch_true, const1); + } + + auto compute_graph = builder.GetGraph(); + + return compute_graph; +} + +/// +/// GraphWithMultiOutputPrefetch: Prefetch has more than one output +/// +/// w1 w2 w3 w4 w5 +/// \ \ \ \ \ +/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 +/// / \ / \ / \ / \ / +/// / \ / \ / \ / \ / +/// const1 ----- add1 add2 add3 add4 add5 +/// | \ | / | +/// | \ | / | +/// | \ | / | +/// | \ | / | +/// -------------- net_output --------------- +/// +/// Memory distribution: +/// +/// |___w1__|__w2__|__w3__|__| +/// +/// |_____w4_____|_____w5____| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiOutputPrefetch() { + auto builder = InnerGraphBuilder(graph_name_); + auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); + auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); + auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); + auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); + auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); + + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + + auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); + auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); + auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); + + auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); + auto add1 = builder.AddNode("add1", ADD, 2, 1); + auto add2 = builder.AddNode("add2", ADD, 2, 1); + auto add3 = builder.AddNode("add3", ADD, 2, 1); + auto add4 = builder.AddNode("add4", ADD, 2, 1); + auto add5 = builder.AddNode("add5", ADD, 2, 1); + auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0); + + builder.AddDataEdge(w1, 0, prefetch1, 0); + builder.AddDataEdge(w2, 0, prefetch2, 0); + builder.AddDataEdge(w3, 0, prefetch3, 0); + builder.AddDataEdge(w4, 0, prefetch4, 0); + builder.AddDataEdge(w5, 0, prefetch5, 0); + + builder.AddDataEdge(const1, 0, add1, 0); + builder.AddDataEdge(prefetch1, 0, add1, 1); + + builder.AddDataEdge(prefetch1, 0, add2, 0); + builder.AddDataEdge(prefetch2, 0, add2, 1); + + builder.AddDataEdge(prefetch2, 0, add3, 0); + builder.AddDataEdge(prefetch3, 0, add3, 1); + + builder.AddDataEdge(prefetch3, 0, add4, 0); + builder.AddDataEdge(prefetch4, 0, add4, 1); + + builder.AddDataEdge(prefetch4, 0, add5, 0); + builder.AddDataEdge(prefetch5, 0, add5, 1); + + builder.AddDataEdge(add1, 0, net_output, 0); + builder.AddDataEdge(add2, 0, net_output, 1); + builder.AddDataEdge(add3, 0, net_output, 2); + builder.AddDataEdge(add4, 0, net_output, 3); + builder.AddDataEdge(add5, 0, net_output, 4); + + auto compute_graph = builder.GetGraph(); + + return compute_graph; +} + +/// +/// GraphWithMultiOutputPrefetch: Prefetch has more than one output +/// +/// w1 w2 w3 w4 w5 +/// \ / \ / \ / \ / \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 +/// / \ / \ / \ / \ / +/// / \ / \ / \ / \ / +/// const1 ----- add1 add2 add3 add4 add5 +/// | \ | / | +/// | \ | / | +/// | \ | / | +/// | \ | / | +/// -------------- net_output --------------- +/// +/// Memory distribution: +/// +/// |___w1__|__w2__|__w3__|__| +/// +/// |_____w4_____|_____w5____| +/// +ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiInputOutputPrefetch() { + auto builder = InnerGraphBuilder(graph_name_); + auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); + auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); + auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); + auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); + auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); + + const int64_t buffer_pool_id = 0; + const int64_t buffer_pool_size = 5600; + + auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 2, 2); + SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500, 500}); + auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 2, 2); + SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500, 500}); + auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 2, 2); + SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500, 1024}); + auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 2, 2); + SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024, 1024}); + auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); + SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); + + auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); + auto add1 = builder.AddNode("add1", ADD, 2, 1); + auto add2 = builder.AddNode("add2", ADD, 2, 1); + auto add3 = builder.AddNode("add3", ADD, 2, 1); + auto add4 = builder.AddNode("add4", ADD, 2, 1); + auto add5 = builder.AddNode("add5", ADD, 2, 1); + auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0); + + builder.AddDataEdge(w1, 0, prefetch1, 0); + builder.AddDataEdge(w2, 0, prefetch1, 1); + builder.AddDataEdge(w2, 0, prefetch2, 0); + builder.AddDataEdge(w3, 0, prefetch2, 1); + builder.AddDataEdge(w3, 0, prefetch3, 0); + builder.AddDataEdge(w4, 0, prefetch3, 1); + builder.AddDataEdge(w4, 0, prefetch4, 0); + builder.AddDataEdge(w5, 0, prefetch4, 1); + builder.AddDataEdge(w5, 0, prefetch5, 0); + + builder.AddDataEdge(const1, 0, add1, 0); + builder.AddDataEdge(prefetch1, 0, add1, 1); + + builder.AddDataEdge(prefetch1, 1, add2, 0); + builder.AddDataEdge(prefetch2, 0, add2, 1); + + builder.AddDataEdge(prefetch2, 1, add3, 0); + builder.AddDataEdge(prefetch3, 0, add3, 1); + + builder.AddDataEdge(prefetch3, 1, add4, 0); + builder.AddDataEdge(prefetch4, 0, add4, 1); + + builder.AddDataEdge(prefetch4, 1, add5, 0); + builder.AddDataEdge(prefetch5, 0, add5, 1); + + builder.AddDataEdge(add1, 0, net_output, 0); + builder.AddDataEdge(add2, 0, net_output, 1); + builder.AddDataEdge(add3, 0, net_output, 2); + builder.AddDataEdge(add4, 0, net_output, 3); + builder.AddDataEdge(add5, 0, net_output, 4); + + auto compute_graph = builder.GetGraph(); + + return compute_graph; +} +} // namespace ut +} // namespace ge diff --git a/tests/ut/ge/graph/utils/buffer_pool_graph_builder.h b/tests/ut/ge/graph/utils/buffer_pool_graph_builder.h new file mode 100644 index 00000000..24382dd2 --- /dev/null +++ b/tests/ut/ge/graph/utils/buffer_pool_graph_builder.h @@ -0,0 +1,279 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_ +#define GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_ + +#include +#include + +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/node.h" + +namespace ge { +namespace ut { +class BufferPoolGraphBuilder { + public: + explicit BufferPoolGraphBuilder(const std::string &name = "BufferPoolGraph"); + ~BufferPoolGraphBuilder() {} + class InnerGraphBuilder { + public: + explicit InnerGraphBuilder(const std::string &name); + ~InnerGraphBuilder() {} + NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, + Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, + std::vector shape = {1, 1, 224, 224}); + + void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); + + void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); + + ComputeGraphPtr GetGraph() { + graph_->TopologicalSorting(); + return graph_; + } + private: + ComputeGraphPtr graph_; + }; + + /// + /// Normal graph + /// + /// w1 w2 w3 w4 w5 + /// \ \ \ \ \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 + /// \ \ \ \ \ + /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output + /// + /// + /// Memory distribution: + /// + /// |___w1__|__w2__|__w3__|__| + /// + /// |_____w4_____|_____w5____| + /// + ComputeGraphPtr BuildNormalGraph(); + + /// + /// Normal graph with multi buffer pool + /// + /// w1 w2 w3 w4 w5 + /// \ \ \ \ \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 + /// (pool0) (pool1) (pool0) (pool0) (pool1) + /// \ \ \ \ \ + /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output + /// + /// + /// Memory distribution: + /// + /// |___w1__|__w3__|_________| + /// |_____w4_____|___________| + /// + /// |___w2__|_____w5___|_____| + /// + ComputeGraphPtr BuildNormalGraphWithMultiBufferPool(); + + /// + /// SerialGraph: Buffer pool size only can contain one prefetch node + /// + /// w1 w2 w3 w4 w5 + /// \ \ \ \ \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 + /// \ \ \ \ \ + /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output + /// + /// + /// Memory distribution: + /// + /// |____w1_____|__| + /// + /// |____w2_____|__| + /// + /// |____w3_____|__| + /// + /// |______w4______| + /// + /// |______w5______| + /// + ComputeGraphPtr BuildSerialGraph(); + + /// + /// GraphWithMultiPrefetch: Calc node with more prefetch node + /// + /// w1 w2 w3 w4 w5 + /// \ \ \ \ \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1 + /// \ / \ / \ / + /// \ / \ / \ / + /// \ / \ / \ / + /// add1 ------ c ------- add2 ----- c ----- add3 + /// | | | + /// | | | + /// --------------- net_output ------------ + /// + /// Memory distribution: + /// + /// |___w1__|__w2__|__w3__|__| + /// + /// |_____w4_____|_____w5____| + /// + ComputeGraphPtr BuildGraphWithMultiPrefetch(); + + /// + /// GraphWithSubgraph: Calc node in different subgraph + /// + /// + /// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output + /// + /// + /// Subgraph1: Subgraph2: + /// + /// w1 w2 w3 w4 w5 + /// \ \ \ \ \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 + /// \ \ \ \ \ + /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out + /// + /// + /// Memory distribution: + /// + /// |___w1__|__w2__|__w3__|__| + /// + /// |_____w4_____|_____w5____| + /// + ComputeGraphPtr BuildGraphWithSubgraph(); + + /// + /// SubgraphWithInnerDependency: Calc node in different subgraph with inner dependency + /// + /// + /// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output + /// + /// + /// Subgraph1: Subgraph2: + /// + /// w1 w2 w3 w4 w5 + /// \ \ \ \ \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 + /// \ \ \ \ \ + /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out + /// + /// + /// Memory distribution: + /// + /// |___w1__|__w2__|__w3__|__| + /// + /// |_____w4_____|_____w5____| + /// + ComputeGraphPtr BuildSubgraphWithInnerDependency(); + + /// + /// BuildGraphWithMultiBatch: Different batch label + /// + /// + /// batch_label_128 + /// + /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- + /// / / / / / / \ + /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \ + /// const1 switch_false / / / / / \ + /// \ / / / / / / \ + /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output + /// / \ \ \ \ \ \ / + /// const2 switch_true \ \ \ \ \ / + /// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 / + /// \ \ \ \ \ \ / + /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- + /// + /// batch_label_256 + /// + /// + /// Memory distribution: + /// + /// |___w1__|__w2__|__w3__|__| + /// + /// |_____w4_____|_____w5____| + /// + ComputeGraphPtr BuildGraphWithMultiBatch(); + + /// + /// GraphWithMultiOutputPrefetch: Prefetch has more than one output + /// + /// w1 w2 w3 w4 w5 + /// \ \ \ \ \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 + /// / \ / \ / \ / \ / + /// / \ / \ / \ / \ / + /// const1 ----- add1 add2 add3 add4 add5 + /// | \ | / | + /// | \ | / | + /// | \ | / | + /// | \ | / | + /// -------------- net_output --------------- + /// + /// Memory distribution: + /// + /// |___w1__|__w2__|__w3__|__| + /// + /// |_____w4_____|_____w5____| + /// + ComputeGraphPtr BuildGraphWithMultiOutputPrefetch(); + + /// + /// GraphWithMultiOutputPrefetch: Prefetch has more than one output + /// + /// w1 w2 w3 w4 w5 + /// \ / \ / \ / \ / \ + /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 + /// / \ / \ / \ / \ / + /// / \ / \ / \ / \ / + /// const1 ----- add1 add2 add3 add4 add5 + /// | \ | / | + /// | \ | / | + /// | \ | / | + /// | \ | / | + /// -------------- net_output --------------- + /// + /// Memory distribution: + /// + /// |___w1__|__w2__|__w3__|__| + /// + /// |_____w4_____|_____w5____| + /// + ComputeGraphPtr BuildGraphWithMultiInputOutputPrefetch(); + + void SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size, const std::string &batch_label = ""); + + void SetBatchLabel(NodePtr &node, const std::string &batch_label = ""); + + void SetOutputMemSize(NodePtr &node, const std::vector &mem_size = {1024}); + + void SetWorkSpaceMemSize(NodePtr &node, const std::vector &ws_bytes = {1024}); + + void SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size, + const std::vector &mem_size = {1024}, + const std::vector &ws_bytes = {1024}, + const std::string &batch_label = ""); + + private: + std::string graph_name_; +}; +} // namespace ut +} // namespace ge + +#endif // GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_