buffer pool memory allocator

tags/v1.3.0
wangzhengjun, 3 years ago
commit 2561188d96
25 changed files with 3868 additions and 9 deletions
  1. ge/CMakeLists.txt (+4, -0)
  2. ge/ge_inference.mk (+1, -0)
  3. ge/ge_runner.mk (+1, -0)
  4. ge/graph/build/memory/block_mem_assigner.cc (+10, -1)
  5. ge/graph/build/memory/block_mem_assigner.h (+2, -0)
  6. ge/graph/build/memory/buffer_pool_mem_assigner.cc (+234, -0)
  7. ge/graph/build/memory/buffer_pool_mem_assigner.h (+83, -0)
  8. ge/graph/build/memory/graph_mem_assigner.cc (+52, -0)
  9. ge/graph/build/memory/graph_mem_assigner.h (+2, -0)
  10. ge/graph/build/memory/module.mk (+1, -0)
  11. ge/graph/build/run_context.cc (+4, -1)
  12. ge/graph/build/stream_allocator.cc (+220, -4)
  13. ge/graph/build/stream_allocator.h (+4, -0)
  14. ge/graph/common/omg_util.cc (+40, -0)
  15. ge/graph/common/omg_util.h (+21, -0)
  16. ge/graph/load/model_manager/davinci_model.cc (+7, -3)
  17. ge/graph/manager/graph_manager.cc (+7, -0)
  18. ge/graph/passes/buffer_pool_memory_pass.cc (+574, -0)
  19. ge/graph/passes/buffer_pool_memory_pass.h (+136, -0)
  20. tests/depends/runtime/src/runtime_stub.cc (+5, -0)
  21. tests/ut/ge/CMakeLists.txt (+5, -0)
  22. tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc (+607, -0)
  23. tests/ut/ge/graph/passes/buffer_pool_memory_pass_unittest.cc (+591, -0)
  24. tests/ut/ge/graph/utils/buffer_pool_graph_builder.cc (+978, -0)
  25. tests/ut/ge/graph/utils/buffer_pool_graph_builder.h (+279, -0)
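The feature is driven entirely by node attributes: BufferPoolMemoryPass and BufferPoolMemAssigner only act on nodes that carry both ATTR_NAME_BUFFER_POOL_ID and ATTR_NAME_BUFFER_POOL_SIZE (see IsBufferPoolMemEnable in buffer_pool_memory_pass.cc below). A minimal sketch of tagging a prefetch-style op, in the spirit of the unit-test graph builder added under tests/ut/ge/graph/utils/; the function name and the pool id/size values are illustrative only and not part of this commit:

// Illustrative sketch only: mark an op so that the new pass and assigner
// manage its output in a buffer pool. Attribute names come from
// graph/debug/ge_attr_define.h; the pool id/size values are made up.
#include "graph/debug/ge_attr_define.h"
#include "graph/op_desc.h"
#include "graph/utils/attr_utils.h"

void TagAsBufferPoolNode(const ge::OpDescPtr &op_desc) {
  // Both attributes must be present; otherwise the node keeps going through
  // the normal block memory assigner (see block_mem_assigner.cc below).
  (void)ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_BUFFER_POOL_ID, 0);      // pool 0
  (void)ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_BUFFER_POOL_SIZE, 5600); // 5600-byte pool
}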

ge/CMakeLists.txt (+4, -0)

@@ -329,6 +329,7 @@ set(TRAIN_SRC_LIST
"graph/passes/memcpy_addr_async_pass.cc"
"graph/passes/parallel_group_pass.cc"
"graph/passes/set_input_output_offset_pass.cc"
"graph/passes/buffer_pool_memory_pass.cc"
"graph/preprocess/graph_preprocess.cc"
"graph/preprocess/insert_op/ge_aipp_op.cc"
"graph/preprocess/insert_op/util_insert_aipp_op.cc"
@@ -407,6 +408,7 @@ set(TRAIN_SRC_LIST
"graph/build/memory/hybrid_mem_assigner.cc"
"graph/build/memory/max_block_mem_assigner.cc"
"graph/build/memory/var_mem_assign_util.cc"
"graph/build/memory/buffer_pool_mem_assigner.cc"
)

set(INFER_SRC_LIST
@@ -617,6 +619,7 @@ set(INFER_SRC_LIST
"graph/passes/memcpy_addr_async_pass.cc"
"graph/passes/set_input_output_offset_pass.cc"
"graph/passes/parallel_group_pass.cc"
"graph/passes/buffer_pool_memory_pass.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc"
@@ -680,6 +683,7 @@ set(INFER_SRC_LIST
"graph/build/memory/hybrid_mem_assigner.cc"
"graph/build/memory/max_block_mem_assigner.cc"
"graph/build/memory/var_mem_assign_util.cc"
"graph/build/memory/buffer_pool_mem_assigner.cc"
)

if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)


ge/ge_inference.mk (+1, -0)

@@ -222,6 +222,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/hccl_group_pass.cc \
graph/passes/memcpy_addr_async_pass.cc \
graph/passes/set_input_output_offset_pass.cc \
graph/passes/buffer_pool_memory_pass.cc \

OMG_DEVICE_SRC_FILES := $(OMG_HOST_SRC_FILES)



ge/ge_runner.mk (+1, -0)

@@ -246,6 +246,7 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/end_of_sequence_add_control_pass.cc \
graph/passes/memcpy_addr_async_pass.cc \
graph/passes/set_input_output_offset_pass.cc \
graph/passes/buffer_pool_memory_pass.cc \
graph/preprocess/graph_preprocess.cc \
graph/preprocess/insert_op/ge_aipp_op.cc \
graph/preprocess/insert_op/util_insert_aipp_op.cc \


ge/graph/build/memory/block_mem_assigner.cc (+10, -1)

@@ -1655,6 +1655,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector
bool is_atomic = false;
// If GetBool fail, is_atomic is false.
(void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic);
bool is_buffer_pool_mem_supported = (op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID)) &&
(op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) && (!root_unknown_shape_flag_);
// Allocate memory for the current node and release node memory of the same size in the workspace
GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1",
for (auto iter = stream_workspace_blocks_.begin(); iter != stream_workspace_blocks_.end();
@@ -1694,7 +1696,7 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector
GE_IF_BOOL_EXEC(!no_need_assign_memory,
no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input););
}
no_need_assign_memory = (no_need_assign_memory || IsKnownSubgraphData(node));
no_need_assign_memory = (no_need_assign_memory || IsKnownSubgraphData(node) || is_buffer_pool_mem_supported);
if (no_need_assign_memory) {
zero_memory_list_.emplace_back(node, kOutput, i, false);
continue;
@@ -1740,6 +1742,13 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) {
const char *op_no_reuse_mem = std::getenv(OP_NO_REUSE_MEM);
GE_IF_BOOL_EXEC(op_no_reuse_mem != nullptr, op_no_reuse_mem_str = string(op_no_reuse_mem);
CheckAndGetOpReuseEnv(op_no_reuse_mem_str, op_no_reuse_mem_vec_, op_reuse_env_valid_););
auto root_graph = GraphUtils::FindRootGraph(compute_graph_);
if (root_graph == nullptr) {
GELOGE(INTERNAL_ERROR, "[Check][RootGraph]Root graph is nullptr, graph:%s.", compute_graph_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Root graph is nullptr, graph:%s.", compute_graph_->GetName().c_str());
return;
}
root_unknown_shape_flag_ = root_graph->GetGraphUnknownFlag();

for (NodePtr &n : compute_graph_->GetAllNodes()) {
auto node_op_desc = n->GetOpDesc();


ge/graph/build/memory/block_mem_assigner.h (+2, -0)

@@ -494,6 +494,8 @@ class BlockMemAssigner : public MemAssigner {
/// @ [stream2][nodeid]
///
DependStreamLife total_node_depend_stream_life_;

bool root_unknown_shape_flag_ = false;
};
} // namespace ge
#endif // GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_

ge/graph/build/memory/buffer_pool_mem_assigner.cc (+234, -0)

@@ -0,0 +1,234 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/build/memory/buffer_pool_mem_assigner.h"
#include "graph/common/omg_util.h"
#include "graph/utils/tensor_utils.h"
#include "framework/common/util.h"
#include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h"
#include "common/math/math_util.h"
#include "common/util/error_manager/error_manager.h"

namespace ge {
namespace {
const size_t kBufferPoolNodeMemInfoLength = 2;
const uint32_t kBufferPoolNodeOutputSizeIndex = 0;
const uint32_t kBufferPoolNodeOutputOffsetIndex = 1;
} // namespace

Status BufferPoolMemAssigner::Assign() {
if (compute_graph_ == nullptr) {
GELOGE(PARAM_INVALID, "[Check][Graph]Graph is nullptr");
REPORT_INNER_ERROR("E19999", "Input graph is nullptr");
return PARAM_INVALID;
}
Status ret = InitAssigner(compute_graph_);
if (ret != SUCCESS) {
GELOGE(FAILED, "[Init][Assigner]Graph:%s.", compute_graph_->GetName().c_str());
return FAILED;
}
ret = AssignOutput();
if (ret != SUCCESS) {
GELOGE(FAILED, "[Assign][Output]Graph:%s.", compute_graph_->GetName().c_str());
return FAILED;
}
return SUCCESS;
}

Status BufferPoolMemAssigner::GetOutputMemoryType(const NodePtr &node, size_t idx, int64_t &memory_type) {
GE_CHECK_NOTNULL(node->GetOpDesc());
memory_type = RT_MEMORY_HBM;
std::vector<int64_t> type_list;
bool has_mem_type = ge::AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_OUTPUT_MEM_TYPE_LIST, type_list);
if (has_mem_type && (type_list.size() != node->GetOpDesc()->GetOutputsSize() || idx >= type_list.size())) {
GELOGE(PARAM_INVALID, "[Check][OutputParam]Output param invalid, output size:%zu, mem type size:%zu, index:%zu.",
node->GetOpDesc()->GetOutputsSize(), type_list.size(), idx);
REPORT_INNER_ERROR("E19999", "Output param invalid, output size:%zu, mem type size:%zu, index:%zu.",
node->GetOpDesc()->GetOutputsSize(), type_list.size(), idx);
return PARAM_INVALID;
}
memory_type = has_mem_type ? type_list[idx] : RT_MEMORY_HBM;
return SUCCESS;
}

Status BufferPoolMemAssigner::InitAssigner(const ComputeGraphPtr &graph) {
for (const NodePtr &node : graph->GetAllNodes()) {
int64_t buffer_pool_id = 0;
int64_t buffer_pool_size = 0;
bool get_attr = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, buffer_pool_id);
get_attr = get_attr && (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, buffer_pool_size));
if (get_attr) {
std::string batch_label;
(void) AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
buffer_pool_nodes_[batch_label][buffer_pool_id].emplace_back(node);
auto iter = buffer_pool_size_[batch_label].find(buffer_pool_id);
if (iter == buffer_pool_size_[batch_label].end()) {
buffer_pool_size_[batch_label][buffer_pool_id] = buffer_pool_size;
}
Status ret = InitMemOffsetBase(node);
if (ret != SUCCESS) {
GELOGE(ret, "[Init][MemOffsetBase]Batch label:%s.", batch_label.c_str());
REPORT_INNER_ERROR("E19999", "Failed to init offset base, batch label:%s.", batch_label.c_str());
return ret;
}
}
}

int64_t max_size = 0;
for (const auto &iter : buffer_pool_size_) {
std::string batch_label = iter.first;
int64_t batch_offset = mem_offset_base_;
for (const auto &buffer_pool : iter.second) {
int64_t buffer_pool_id = buffer_pool.first;
int64_t buffer_pool_size = buffer_pool.second;
buffer_pool_offset_base_[batch_label][buffer_pool_id] = batch_offset;
FMK_INT64_ADDCHECK(buffer_pool_size, kBufferPoolMemAlignSize);
AlignMemSize(buffer_pool_size, kBufferPoolMemAlignSize);
FMK_INT64_ADDCHECK(batch_offset, (buffer_pool_size + kBufferPoolMemAlignSize));
batch_offset += (buffer_pool_size + kBufferPoolMemAlignSize);
}
int64_t batch_mem_size = batch_offset - mem_offset_base_;
GELOGI("[Init][Assigner]Get batch mem size, batch label:%s, mem size:%ld.", batch_label.c_str(), batch_mem_size);
if (max_size < batch_mem_size) {
max_size = batch_mem_size;
}
}
FMK_INT64_ADDCHECK(mem_offset_base_, max_size);
mem_offset_ = static_cast<size_t>(mem_offset_base_ + max_size);
GELOGI("[Init][Assigner]Init buffer pool mem assigner successfully, "
"mem type:%ld, mem offset base:%ld, mem offset:%zu.", mem_type_, mem_offset_base_, mem_offset_);
return SUCCESS;
}

Status BufferPoolMemAssigner::InitMemOffsetBase(const NodePtr &node) {
int64_t mem_type;
Status ret = GetOutputMemoryType(node, static_cast<size_t>(kBufferPoolNodeOutIndex), mem_type);
if (ret != SUCCESS) {
GELOGE(ret, "[Get][MemType]Node:%s, index:%u.", node->GetName().c_str(), kBufferPoolNodeOutIndex);
REPORT_INNER_ERROR("E19999", "Failed to get output memory type, node:%s, index:%u.",
node->GetName().c_str(), kBufferPoolNodeOutIndex);
return ret;
}
if (mem_type_ != mem_type && init_offset_base_) {
GELOGE(PARAM_INVALID, "[Check][MemType]The memory type of all buffer pool nodes must be the same, node:%s, "
"required:%ld, actually: %ld", node->GetName().c_str(), mem_type_, mem_type);
REPORT_INNER_ERROR("E19999", "The memory type of all buffer pool nodes must be the same, node:%s, "
"required:%ld, actually: %ld", node->GetName().c_str(), mem_type_, mem_type);
return PARAM_INVALID;
}
if (!init_offset_base_) {
auto iter = mem_type_to_offset_.find(mem_type);
if (iter == mem_type_to_offset_.end()) {
GELOGE(PARAM_INVALID, "[Check][MemType]Memory type is not supported, node:%s, mem type:%ld.",
node->GetName().c_str(), mem_type);
REPORT_INNER_ERROR("E19999", "Memory type is not supported, node:%s, mem type:%ld.",
node->GetName().c_str(), mem_type);
return PARAM_INVALID;
}
mem_offset_base_ = static_cast<int64_t>(iter->second);
FMK_INT64_ADDCHECK(mem_offset_base_, (kBufferPoolMemAlignSize + kBufferPoolMemAlignSize));
AlignMemSize(mem_offset_base_, kBufferPoolMemAlignSize);
// The HCOM nodes may access the previous 512 bytes.
mem_offset_base_ += kBufferPoolMemAlignSize;
mem_type_ = mem_type;
init_offset_base_ = true;
GELOGI("[Init][MemOffsetBase]Init offset base:%ld, memory type:%ld", mem_offset_base_, mem_type);
}
return SUCCESS;
}

Status BufferPoolMemAssigner::AssignOutput() {
for (auto &batch_pool_nodes_map : buffer_pool_nodes_) {
std::string batch_label = batch_pool_nodes_map.first;
for (auto &pool_nodes_map : batch_pool_nodes_map.second) {
int64_t buffer_pool_id = pool_nodes_map.first;
auto iter_buffer_id_size = buffer_pool_size_[batch_label].find(buffer_pool_id);
if (iter_buffer_id_size == buffer_pool_size_[batch_label].end()) {
GELOGE(INTERNAL_ERROR, "[Get][BufferPoolSize]Pool id:%ld.", buffer_pool_id);
REPORT_INNER_ERROR("E19999", "Failed to get buffer pool size, pool id:%ld.", buffer_pool_id);
return INTERNAL_ERROR;
}
auto iter_buffer_id_offset = buffer_pool_offset_base_[batch_label].find(buffer_pool_id);
if (iter_buffer_id_offset == buffer_pool_offset_base_[batch_label].end()) {
GELOGE(INTERNAL_ERROR, "[Get][BufferPoolBaseOffset]Pool id:%ld.", buffer_pool_id);
REPORT_INNER_ERROR("E19999", "Failed to get buffer pool base offset, pool id:%ld.", buffer_pool_id);
return INTERNAL_ERROR;
}
int64_t buffer_pool_size = iter_buffer_id_size->second;
int64_t output_offset_base = iter_buffer_id_offset->second;
Status ret = AssignOutputInOneBufferPool(batch_label, output_offset_base, pool_nodes_map.second);
if (ret != SUCCESS) {
GELOGE(ret, "[Assign][OneBufferPool]Batch label:%s, pool id:%ld, pool size:%ld, offset base:%ld.",
batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base);
REPORT_INNER_ERROR("E19999", "Failed to assign output memory, batch label:%s, "
"pool id:%ld, pool size:%ld, offset base:%ld.",
batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base);
return ret;
}
GELOGI("[Assign][Output]Assign output successfully, batch label:%s, pool id:%ld, pool size:%ld, offset base:%ld.",
batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base);
}
}
return SUCCESS;
}

Status BufferPoolMemAssigner::AssignOutputInOneBufferPool(const std::string &batch_label,
int64_t output_offset_base,
const std::vector<NodePtr> &buffer_pool_nodes) {
for (const NodePtr &node : buffer_pool_nodes) {
int64_t output_size = 0;
Status ret = GetMemorySize(node, output_size);
if (ret != SUCCESS) {
GELOGE(ret, "[Get][MemSize]Node:%s.", node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to get output size, node:%s.", node->GetName().c_str());
return ret;
}
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
vector<int64_t> memory_size_and_offset;
bool get_attr = AttrUtils::GetListInt(op_desc, ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET, memory_size_and_offset);
if (!get_attr || memory_size_and_offset.size() != kBufferPoolNodeMemInfoLength) {
GELOGE(PARAM_INVALID, "[Get][Attr]Node:%s, mem info size:%zu, required size:%zu.",
node->GetName().c_str(), memory_size_and_offset.size(), kBufferPoolNodeMemInfoLength);
REPORT_INNER_ERROR("E19999", "Failed to get pool node memory info, node:%s, info size:%zu, required size:%zu.",
node->GetName().c_str(), memory_size_and_offset.size(), kBufferPoolNodeMemInfoLength);
return PARAM_INVALID;
}
if (output_size != memory_size_and_offset[kBufferPoolNodeOutputSizeIndex]) {
GELOGE(PARAM_INVALID, "[Check][MemSize]Something wrong with memory size, pre size:%ld, curr size:%ld, node:%s.",
memory_size_and_offset[kBufferPoolNodeOutputSizeIndex], output_size, node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Something wrong with memory size, pre size:%ld, curr size:%ld, node:%s.",
memory_size_and_offset[kBufferPoolNodeOutputSizeIndex], output_size, node->GetName().c_str());
return PARAM_INVALID;
}

int64_t logical_offset = memory_size_and_offset[kBufferPoolNodeOutputOffsetIndex];
vector<int64_t> output_list = {(output_offset_base + logical_offset)};
op_desc->SetOutputOffset(output_list);
// log for IMAS tools
GELOGI("[IMAS]Set %s name[%s] optype[%s] %s[%u] offset to [%ld] streamid[%ld] memtype[%ld] "
"size[%zu] realsize[%zu] noalignsize[%zu] life time begin[%d] life time end[%d] "
"child[%d:%d:%d:%d:%d] isref[%d] batch[%s]",
compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
"output", kBufferPoolNodeOutIndex, output_list[kBufferPoolNodeOutIndex], op_desc->GetStreamId(), mem_type_,
static_cast<size_t>(output_size), static_cast<size_t>(output_size), static_cast<size_t>(output_size),
0, 0, 0, 0, 0, 0, 0, 0, batch_label.c_str());
}
return SUCCESS;
}

} // namespace ge
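To make the layout computed by InitAssigner/InitMemOffsetBase above concrete, a worked example with assumed numbers (not taken from a real graph): the incoming offset for the memory type is 500, and one batch uses pool 0 of size 5600 bytes and pool 1 of size 2000 bytes, with kBufferPoolMemAlignSize = 512.

// InitMemOffsetBase: 500 aligned up to 512, plus the 512-byte HCOM guard -> mem_offset_base_ = 1024
// pool 0: offset base 1024; size 5600 aligned to 5632, plus 512 spacing -> next offset 7168
// pool 1: offset base 7168; size 2000 aligned to 2048, plus 512 spacing -> next offset 9728
// batch mem size = 9728 - 1024 = 8704; mem_offset_ = 1024 + max batch size = 9728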

ge/graph/build/memory/buffer_pool_mem_assigner.h (+83, -0)

@@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_
#define GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_

#include <vector>
#include <map>
#include <unordered_map>
#include "graph/build/memory/mem_assigner.h"
#include "runtime/mem.h"

namespace ge {
class BufferPoolMemAssigner : public MemAssigner {
public:
BufferPoolMemAssigner(ComputeGraphPtr compute_graph, const std::map<int64_t, size_t> &mem_type_to_offset)
: MemAssigner(), compute_graph_(compute_graph),
mem_type_(0),
mem_offset_(0),
mem_offset_base_(0),
init_offset_base_(false),
mem_type_to_offset_(mem_type_to_offset) {}

BufferPoolMemAssigner(const BufferPoolMemAssigner &) = delete;

BufferPoolMemAssigner &operator=(const BufferPoolMemAssigner &) = delete;

~BufferPoolMemAssigner() override = default;

Status Assign() override;

size_t GetMemOffset() const { return mem_offset_; }

int64_t GetMemType() const { return mem_type_; }

private:
static Status GetOutputMemoryType(const NodePtr &node, size_t idx, int64_t &memory_type);

Status InitAssigner(const ComputeGraphPtr &graph);

Status InitMemOffsetBase(const NodePtr &node);

Status AssignOutput();

Status AssignOutputInOneBufferPool(const std::string &batch_label,
int64_t output_offset_base,
const std::vector<NodePtr> &buffer_pool_nodes);

ComputeGraphPtr compute_graph_;

int64_t mem_type_;

size_t mem_offset_;

int64_t mem_offset_base_;

bool init_offset_base_;

std::map<int64_t, size_t> mem_type_to_offset_;

// Use map to ensure that each visit is in the order of pool id
std::unordered_map<std::string, std::map<int64_t, std::vector<NodePtr>>> buffer_pool_nodes_;

// Use map to ensure that each visit is in the order of pool id
std::unordered_map<std::string, std::map<int64_t, int64_t>> buffer_pool_size_;

std::unordered_map<std::string, std::unordered_map<int64_t, int64_t>> buffer_pool_offset_base_;
};
} // namespace ge
#endif // GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_
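A usage sketch for this interface, mirroring the call site added to graph_mem_assigner.cc below; the wrapper function and the starting offset value are assumptions for illustration only:

// Sketch only: drive the assigner from an existing per-memory-type offset,
// then write the enlarged offset back. See AssignBufferPoolMemory() in
// ge/graph/build/memory/graph_mem_assigner.cc (next file) for the real wiring.
#include <map>
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "graph/build/memory/buffer_pool_mem_assigner.h"

ge::Status AssignBufferPools(const ge::ComputeGraphPtr &compute_graph, size_t current_hbm_offset) {
  std::map<int64_t, size_t> mem_type_to_offset = {{RT_MEMORY_HBM, current_hbm_offset}};
  ge::BufferPoolMemAssigner assigner(compute_graph, mem_type_to_offset);
  ge::Status ret = assigner.Assign();
  if (ret != ge::SUCCESS) {
    return ret;
  }
  // GetMemOffset() is the new total offset for the memory type reported by GetMemType().
  GELOGI("Buffer pool assigned, mem type:%ld, new offset:%zu.", assigner.GetMemType(), assigner.GetMemOffset());
  return ge::SUCCESS;
}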

ge/graph/build/memory/graph_mem_assigner.cc (+52, -0)

@@ -30,6 +30,7 @@
#include "graph/manager/graph_var_manager.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/build/memory/buffer_pool_mem_assigner.h"

namespace {
const int kAllInputAddrIsAtomic = -1;
@@ -231,6 +232,7 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, map<int64_t, size

GE_CHK_STATUS_RET(ReAssignContinuousMemory(is_loop_graph), "ReAssignContinuousMemory Failed!");
GE_CHK_STATUS_RET(ReAssignAtomicMemory(is_loop_graph), "ReAssignAtomicMemory Failed!");
GE_CHK_STATUS_RET(AssignBufferPoolMemory(), "AssignBufferPoolMemory Failed!");

size_t total_mem_offset = 0;
for (auto pair : memory_offset_) {
@@ -1735,4 +1737,54 @@ ge::Status GraphMemoryAssigner::AssignContinuousInputMemoryWithAtomicProcess(con
return ge::SUCCESS;
}

Status GraphMemoryAssigner::AssignBufferPoolMemory() {
auto is_buffer_pool_mem_enable = [] (const ComputeGraphPtr &graph) -> bool {
for (NodePtr &node : graph->GetAllNodes()) {
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
bool has_attrs = op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE);
if (has_attrs) {
return true;
}
}
return false;
};
auto root_graph = GraphUtils::FindRootGraph(compute_graph_);
GE_CHECK_NOTNULL(root_graph);
if (root_graph->GetGraphUnknownFlag()) {
GELOGI("[Check][Enable]Unknown root graph does not support buffer pool memory, graph:%s.",
compute_graph_->GetName().c_str());
return SUCCESS;
}
if (!is_buffer_pool_mem_enable(compute_graph_)) {
GELOGD("[Check][Enable]Buffer pool memory is not enable, graph:%s.", compute_graph_->GetName().c_str());
return SUCCESS;
}
map<int64_t, size_t> mem_type_to_offset;
for (const auto &pair : memory_offset_) {
mem_type_to_offset[pair.first] = pair.second.mem_offset_;
}
BufferPoolMemAssigner buffer_pool_mem_assigner(compute_graph_, mem_type_to_offset);
Status status = buffer_pool_mem_assigner.Assign();
if (status != SUCCESS) {
GELOGE(status, "[Assign][BufferPoolMem]Graph:%s.", compute_graph_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to assign buffer pool memory, graph:%s.", compute_graph_->GetName().c_str());
return status;
}
int64_t mem_type = buffer_pool_mem_assigner.GetMemType();
auto iter = memory_offset_.find(mem_type);
if (iter == memory_offset_.end()) {
GELOGE(FAILED, "[Check][MemType]Memory type is not supported, graph:%s, mem type:%ld.",
compute_graph_->GetName().c_str(), mem_type);
REPORT_INNER_ERROR("E19999", "Memory type is not supported, graph:%s, mem type:%ld.",
compute_graph_->GetName().c_str(), mem_type);
return FAILED;
}
iter->second.mem_offset_ = buffer_pool_mem_assigner.GetMemOffset();
GELOGI("[Assign][BufferPoolMem]Assign buffer pool memory successfully, graph:%s, mem type:%ld, mem offset:%zu.",
compute_graph_->GetName().c_str(), mem_type, buffer_pool_mem_assigner.GetMemOffset());
return SUCCESS;
}
} // namespace ge

ge/graph/build/memory/graph_mem_assigner.h (+2, -0)

@@ -188,6 +188,8 @@ class GraphMemoryAssigner {

void PrintMemoryOffset();

Status AssignBufferPoolMemory();

MemoryOffsetMap memory_offset_;
ge::ComputeGraphPtr compute_graph_;
HybridMemAssignerPtr mem_assigner_;


ge/graph/build/memory/module.mk (+1, -0)

@@ -8,6 +8,7 @@ local_lib_src_files := memory_assigner.cc \
hybrid_mem_assigner.cc \
max_block_mem_assigner.cc \
var_mem_assign_util.cc \
buffer_pool_mem_assigner.cc \

local_lib_inc_path := ${LOCAL_PATH} \
${TOPDIR}inc \


ge/graph/build/run_context.cc (+4, -1)

@@ -18,6 +18,7 @@
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/common/omg_util.h"

namespace ge {
RunContextUtil::~RunContextUtil() { DestroyRtModelResources(); }
@@ -88,9 +89,11 @@ Status RunContextUtil::CreateRtModelResources(uint32_t stream_num, uint32_t even
}

// Create rt event
uint32_t create_flag = static_cast<uint32_t>((event_num > kEventReuseThreshold) ? RT_EVENT_WITH_FLAG :
RT_EVENT_DEFAULT);
for (uint32_t i = 0; i < event_num; ++i) {
rtEvent_t event = nullptr;
rt_ret = rtEventCreate(&event);
rt_ret = rtEventCreateWithFlag(&event, create_flag);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "call rtEventCreate fail, ret:%d, index:%u, when %s",
static_cast<int>(rt_ret), i, __FUNCTION__);


ge/graph/build/stream_allocator.cc (+220, -4)

@@ -27,6 +27,8 @@
#include "graph/ge_context.h"
#include "graph/utils/graph_utils.h"
#include "init/gelib.h"
#include "common/string_util.h"
#include "common/util/error_manager/error_manager.h"

using std::map;
using std::set;
@@ -38,6 +40,13 @@ const int64_t kTaskNumPerNormalNode = 3;
const int64_t kTaskNumPerHcclNode = 245;
const char *const kTrueStr = "true";
const char *const kFalseStr = "false";
const size_t kEventMultiplexingItemCount = 3;
const size_t kKeyWordIndex = 0;
const size_t kNodeNameIndex = 1;
const size_t kEventIdIndex = 2;
const char *const kSend = "SendTo";
const char *const kRecv = "RecvFrom";
const char kDelim = ';';

inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string &continuous_stream_label) {
if (ge::AttrUtils::GetStr(op_desc, ge::ATTR_NAME_CONTINUOUS_STREAM_LABEL, continuous_stream_label)) {
@@ -52,6 +61,97 @@ bool IsHcclOp(const string &op_type) {
ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE});
return hccl_op_types.find(op_type) != hccl_op_types.end();
}

ge::Status ParseNodeEventMultiplexing(const ge::NodePtr &node,
const std::vector<std::string> &raw_event_multiplexing,
std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_send,
std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_recv) {
GE_CHECK_NOTNULL(node);
for (const auto &str : raw_event_multiplexing) {
std::vector<std::string> ele = ge::StringUtils::Split(str, kDelim);
if (ele.size() != kEventMultiplexingItemCount) {
GELOGE(ge::PARAM_INVALID, "[Check][RawMultiplexing]Size error, node:%s, require size:%zu, actually:%zu.",
node->GetName().c_str(), kEventMultiplexingItemCount, ele.size());
REPORT_INNER_ERROR("E19999", "Raw event multiplexing is invalid, node:%s, require size:%zu, actually:%zu.",
node->GetName().c_str(), kEventMultiplexingItemCount, ele.size());
return ge::PARAM_INVALID;
}
int value;
try {
value = std::stoi(ele[kEventIdIndex]);
} catch (std::invalid_argument &) {
GELOGE(ge::PARAM_INVALID, "[Throw][Exception]Event id is invalid, node:%s, raw:%s.",
node->GetName().c_str(), ele[kEventIdIndex].c_str());
REPORT_INNER_ERROR("E19999", "Event id is invalid, node:%s, raw:%s.",
node->GetName().c_str(), ele[kEventIdIndex].c_str());
return ge::PARAM_INVALID;
} catch (std::out_of_range &) {
GELOGE(ge::PARAM_INVALID, "[Throw][Exception]Event id is out of range, node:%s, raw:%s.",
node->GetName().c_str(), ele[kEventIdIndex].c_str());
REPORT_INNER_ERROR("E19999", "Event id is out of range, node:%s, raw:%s.",
node->GetName().c_str(), ele[kEventIdIndex].c_str());
return ge::PARAM_INVALID;
}
if (value < 0) {
GELOGE(ge::PARAM_INVALID, "[Check][EventId]Event id is out of range, node:%s, raw:%s, value:%d.",
node->GetName().c_str(), ele[kEventIdIndex].c_str(), value);
REPORT_INNER_ERROR("E19999", "Event id is out of range, node:%s, raw:%s, value:%d.",
node->GetName().c_str(), ele[kEventIdIndex].c_str(), value);
return ge::PARAM_INVALID;
}
if (ele[kKeyWordIndex] == kSend) {
node_to_send[node].emplace_back(std::make_pair(ele[kNodeNameIndex], static_cast<uint32_t>(value)));
} else if (ele[kKeyWordIndex] == kRecv) {
node_to_recv[node].emplace_back(std::make_pair(ele[kNodeNameIndex], static_cast<uint32_t>(value)));
} else {
GELOGE(ge::PARAM_INVALID, "[Check][KeyWord]Key word is not supported, node:%s, key:%s.",
node->GetName().c_str(), ele[kEventIdIndex].c_str());
REPORT_INNER_ERROR("E19999", "Key word is not supported, node:%s, key:%s.",
node->GetName().c_str(), ele[kEventIdIndex].c_str());
return ge::PARAM_INVALID;
}
}
return ge::SUCCESS;
}

ge::Status ParseAllNodeEventMultiplexing(const ge::ComputeGraphPtr &graph,
std::unordered_map<std::string, ge::NodePtr> &name_to_node_map,
std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_send,
std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_recv) {
for (const auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
ge::OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
name_to_node_map.insert({node->GetName(), node});
std::vector<std::string> raw_event_multiplexing;
if (!(op_desc->HasAttr(ge::ATTR_NAME_EVENT_MULTIPLEXING))) {
continue;
}
bool get_attr = ge::AttrUtils::GetListStr(op_desc, ge::ATTR_NAME_EVENT_MULTIPLEXING, raw_event_multiplexing);
if (!get_attr) {
GELOGE(ge::PARAM_INVALID, "[Get][Attr]Node:%s.", node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to get raw event multiplexing, node:%s.", node->GetName().c_str());
return ge::PARAM_INVALID;
}
auto parse_ret = ParseNodeEventMultiplexing(node, raw_event_multiplexing, node_to_send, node_to_recv);
if (parse_ret != ge::SUCCESS) {
GELOGE(parse_ret, "[Parse][Eventmultiplexing]Node:%s.", node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to parse node event multiplexing, node:%s.", node->GetName().c_str());
return parse_ret;
}
}
return ge::SUCCESS;
}

std::vector<uint32_t> GetIntersection(std::vector<uint32_t> &a, std::vector<uint32_t> &b) {
std::unordered_set<uint32_t> ele_of_a(a.begin(), a.end());
std::vector<uint32_t> res;
for (auto &ele : b) {
if (ele_of_a.count(ele) > 0) {
res.emplace_back(ele);
}
}
return res;
}
} // namespace

namespace ge {
@@ -150,6 +250,12 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu
return status;
}

status = RefreshEventsWithReuse();
if (status != SUCCESS) {
GELOGE(status, "[Refresh][Events]RefreshEventsWithReuse failed!");
return status;
}

status = InsertSyncEventNodes();
if (status != SUCCESS) {
GELOGE(status, "InsertSyncEventNode failed!");
@@ -1161,6 +1267,94 @@ Status StreamAllocator::CheckStreamActived() const {
return SUCCESS;
}

Status StreamAllocator::ReuseEvent(bool send_to,
const std::unordered_map<std::string, ge::NodePtr> &name_to_node_map,
const std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_event_id) {
for (const auto &node_event_id : node_to_event_id) {
ge::NodePtr curr_node = node_event_id.first;
NodePtr send_node = send_to ? curr_node : nullptr;
NodePtr recv_node = send_to ? nullptr : curr_node;
for (const auto &event_pair : node_event_id.second) {
auto peer_node_iter = name_to_node_map.find(event_pair.first);
if (peer_node_iter == name_to_node_map.end()) {
GELOGE(PARAM_INVALID, "[Get][Node]Name:%s.", event_pair.first.c_str());
REPORT_INNER_ERROR("E19999", "Failed to find node, name:%s.", event_pair.first.c_str());
return PARAM_INVALID;
}
recv_node = send_to ? peer_node_iter->second : recv_node;
send_node = send_to ? send_node : peer_node_iter->second;
GE_CHECK_NOTNULL(send_node);
GE_CHECK_NOTNULL(recv_node);
auto event_id = GetIntersection(node_to_send_events_[send_node], node_to_recv_events_[recv_node]);
uint32_t new_event = event_pair.second + event_num_;
if (event_id.empty()) {
GELOGI("[Check][Optimized]Send:%s, recv:%s.", send_node->GetName().c_str(), recv_node->GetName().c_str());
continue;
} else if (event_id.size() != 1) {
GELOGW("[Check][Event]More than one event are found between %s and %s, event num:%zu.",
send_node->GetName().c_str(), recv_node->GetName().c_str(), event_id.size());
}
uint32_t old_event = event_id[0];
auto reuse_event_id = [] (vector<uint32_t> &event_list, uint32_t old_event, uint32_t new_event) -> void {
event_list.erase(std::remove(event_list.begin(), event_list.end(), old_event), event_list.end());
event_list.push_back(new_event);
return;
};
reuse_event_id(node_to_send_events_[send_node], old_event, new_event);
reuse_event_id(node_to_recv_events_[recv_node], old_event, new_event);
GELOGI("[Reuse][Event]Replace event successfully, send node:%s, recv node:%s, old id:%u, new id:%u.",
send_node->GetName().c_str(), recv_node->GetName().c_str(), old_event, new_event);
}
}
return ge::SUCCESS;
}

// Refresh events to reuse events
Status StreamAllocator::RefreshEventsWithReuse() {
GELOGI("[Refresh][Events]Refresh events with reuse, stream num:%ld, original event num:%u.", stream_num_, event_num_);
if (event_num_ <= kEventReuseThreshold) {
GELOGI("[Check][ReuseThreshold]Event used num is %u, less than %u, skip reuse.",
event_num_, kEventReuseThreshold);
return SUCCESS;
}
std::unordered_map<std::string, NodePtr> name_to_node_map;
std::unordered_map<NodePtr, std::vector<std::pair<std::string, uint32_t>>> node_to_send;
std::unordered_map<NodePtr, std::vector<std::pair<std::string, uint32_t>>> node_to_recv;
Status ret = ParseAllNodeEventMultiplexing(whole_graph_, name_to_node_map, node_to_send, node_to_recv);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][AllNodeEventMultiplexing]Graph:%s.", whole_graph_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to parse all node event multiplexing, graph:%s.",
whole_graph_->GetName().c_str());
return ret;
}
if (node_to_send.empty() && node_to_recv.empty()) {
return SUCCESS;
}

ret = ReuseEvent(true, name_to_node_map, node_to_send);
if (ret != SUCCESS) {
GELOGE(ret, "[Reuse][Event]Phase:Send, graph:%s.", whole_graph_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to reuse event, phase:Send, graph:%s.", whole_graph_->GetName().c_str());
return ret;
}

ret = ReuseEvent(false, name_to_node_map, node_to_recv);
if (ret != SUCCESS) {
GELOGE(ret, "[Reuse][Event]Phase:Recv, graph:%s.", whole_graph_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to reuse event, phase:Recv, graph:%s.", whole_graph_->GetName().c_str());
return ret;
}

Status status = RefreshContinuousEvents();
if (status != SUCCESS) {
GELOGE(status, "[Refresh][ContinuousEvents]Graph:%s.", whole_graph_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to refresh continuous events, graph:%s.", whole_graph_->GetName().c_str());
return status;
}
GELOGI("[Refresh][Events]RefreshEventsWithReuse successfully, event num:%u.", event_num_);
return SUCCESS;
}

// Refresh events to continuous events
Status StreamAllocator::RefreshContinuousEvents() {
// Establish a mapping relationship from old to new event id
@@ -1168,8 +1362,10 @@ Status StreamAllocator::RefreshContinuousEvents() {
uint32_t new_event_id = 0;
for (const auto &one_pair : node_to_send_events_) {
for (const auto &event_id : one_pair.second) {
old_to_new_events[event_id] = new_event_id;
new_event_id++;
if (old_to_new_events.find(event_id) == old_to_new_events.end()) {
old_to_new_events[event_id] = new_event_id;
new_event_id++;
}
}
}

@@ -1208,6 +1404,7 @@ Status StreamAllocator::RefreshContinuousEvents() {

// Insert the real send/recv node in the graph
Status StreamAllocator::InsertSyncEventNodes() {
unordered_map<string, uint32_t> sync_event_name;
for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) {
// Add the node corresponding to the recv event
vector<uint32_t> recv_event_id_list;
@@ -1217,6 +1414,13 @@ Status StreamAllocator::InsertSyncEventNodes() {
GE_CHECK_NOTNULL(node->GetOutControlAnchor());
for (auto &event_id : recv_event_id_list) {
string recv_node_name = whole_graph_->GetName() + "_Recv_" + to_string(event_id);
auto iter = sync_event_name.find(recv_node_name);
if (iter == sync_event_name.end()) {
sync_event_name[recv_node_name] = 1;
} else {
recv_node_name = recv_node_name + "_Reuse_" + to_string(iter->second);
++(iter->second);
}
OpDescPtr op_desc_ptr = MakeShared<OpDesc>(recv_node_name, RECV);
GE_CHECK_NOTNULL(op_desc_ptr);

@@ -1251,6 +1455,13 @@ Status StreamAllocator::InsertSyncEventNodes() {

for (auto &event_id : send_event_id_list) {
string send_node_name = whole_graph_->GetName() + "_Send_" + to_string(event_id);
auto iter = sync_event_name.find(send_node_name);
if (iter == sync_event_name.end()) {
sync_event_name[send_node_name] = 1;
} else {
send_node_name = send_node_name + "_Reuse_" + to_string(iter->second);
++(iter->second);
}
OpDescPtr op_desc_ptr = MakeShared<OpDesc>(send_node_name, SEND);
GE_CHECK_NOTNULL(op_desc_ptr);

@@ -1300,12 +1511,16 @@ void StreamAllocator::DumpEvents() {
GELOGD("After RefreshRealStream: stream %ld.", stream_id);

for (const auto &node : one_pair.second) {
if (node == nullptr || node->GetOpDesc() == nullptr) {
continue;
}
string send_event_str;
for (const auto &send_event_id : node_to_send_events_[node]) {
send_event_str += " " + to_string(send_event_id);
}
if (!send_event_str.empty()) {
GELOGI("node: %s, send events: %s", node->GetName().c_str(), send_event_str.c_str());
GELOGI("node: %s, id: %ld, stream id :%ld, send events: %s.", node->GetName().c_str(),
node->GetOpDesc()->GetId(), node->GetOpDesc()->GetStreamId(), send_event_str.c_str());
}

string recv_event_str;
@@ -1313,7 +1528,8 @@ void StreamAllocator::DumpEvents() {
recv_event_str += " " + to_string(recv_event_id);
}
if (!recv_event_str.empty()) {
GELOGI("node: %s, recv events: %s", node->GetName().c_str(), recv_event_str.c_str());
GELOGI("node: %s, id: %ld, stream id :%ld, recv events: %s.", node->GetName().c_str(),
node->GetOpDesc()->GetId(), node->GetOpDesc()->GetStreamId(), recv_event_str.c_str());
}
}
}
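For reference, the ATTR_NAME_EVENT_MULTIPLEXING entries parsed by ParseNodeEventMultiplexing above are ';'-separated triples of keyword, peer node name, and event id, e.g. "SendTo;NodeB;3" on the send side and "RecvFrom;NodeA;3" on the receive side (BufferPoolMemoryPass writes them in exactly this form, see buffer_pool_memory_pass.cc below). A self-contained illustration of the same three-field split, using std::getline instead of the ge::StringUtils::Split helper used in the real code:

// Standalone sketch of the "keyword;node name;event id" format; not part of the commit.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const std::string raw = "SendTo;NodeB;3";  // one entry of ATTR_NAME_EVENT_MULTIPLEXING
  std::vector<std::string> fields;
  std::stringstream ss(raw);
  std::string item;
  while (std::getline(ss, item, ';')) {
    fields.push_back(item);  // fields: "SendTo", "NodeB", "3"
  }
  // kKeyWordIndex = 0, kNodeNameIndex = 1, kEventIdIndex = 2 in stream_allocator.cc
  std::cout << fields[0] << " -> " << fields[1] << ", event id " << std::stoi(fields[2]) << std::endl;
  return 0;
}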


ge/graph/build/stream_allocator.h (+4, -0)

@@ -71,6 +71,10 @@ class StreamAllocator {
Status SetActiveStreamsForLoop();
Status CheckStreamActived() const;

Status ReuseEvent(bool send_to,
const std::unordered_map<std::string, ge::NodePtr> &name_to_node_map,
const std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_event_id);
Status RefreshEventsWithReuse();
Status RefreshContinuousEvents();

Status InsertSyncEventNodes();


ge/graph/common/omg_util.cc (+40, -0)

@@ -21,6 +21,8 @@
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "common/math/math_util.h"

namespace ge {
///
@@ -204,4 +206,42 @@ Status SetNextIteration(const ge::NodePtr &node, const std::string &next) {

return SUCCESS;
}

///
/// @brief Align the memory
/// @param [in/out] memory size
/// @param [in] alignment
/// @return void
///
void AlignMemSize(int64_t &mem_size, int64_t align_size) {
if (mem_size <= 0) {
return;
}
mem_size = (mem_size + align_size - 1) / align_size * align_size;
}

///
/// @brief Get memory size from tensor desc
/// @param [in] node
/// @param [out] memory size
/// @return Status
///
Status GetMemorySize(const NodePtr &node, int64_t &output_size) {
GE_CHECK_NOTNULL(node->GetOpDesc());
auto output_op_desc = node->GetOpDesc()->GetOutputDescPtr(kBufferPoolNodeOutIndex);
GE_CHECK_NOTNULL(output_op_desc);
int64_t size = 0;
auto ret = ge::TensorUtils::GetSize(*output_op_desc, size);
if (ret != ge::GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Get][Size]Node:%s.", node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to get output size, node:%s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
FMK_INT64_ADDCHECK(size, kBufferPoolMemAlignSize);
AlignMemSize(size, kBufferPoolMemAlignSize);
// The HCOM operator requires an additional 512 bytes before and after
FMK_INT64_ADDCHECK(size, (kBufferPoolMemAlignSize + kBufferPoolMemAlignSize));
output_size = kBufferPoolMemAlignSize + size + kBufferPoolMemAlignSize;
return SUCCESS;
}
} // namespace ge
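A worked example of the padding in GetMemorySize above, with an assumed raw size of 1000 bytes: the size is aligned up to a multiple of 512 and then gets a 512-byte guard on each side for the HCOM operator.

// Illustrative only: reproduces the arithmetic of AlignMemSize and GetMemorySize
// with kBufferPoolMemAlignSize = 512 and an assumed raw tensor size.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t align = 512;                         // kBufferPoolMemAlignSize
  int64_t size = 1000;                               // raw output size from TensorUtils::GetSize
  size = (size + align - 1) / align * align;         // AlignMemSize -> 1024
  const int64_t output_size = align + size + align;  // guard before and after -> 2048
  std::cout << "aligned:" << size << " output:" << output_size << std::endl;
  return 0;
}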

ge/graph/common/omg_util.h (+21, -0)

@@ -27,6 +27,11 @@
#include "graph/node.h"

namespace ge {
namespace {
const int64_t kBufferPoolMemAlignSize = 512;
const uint32_t kBufferPoolNodeOutIndex = 0;
const uint32_t kEventReuseThreshold = 65500;
} // namespace
///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node
@@ -96,6 +101,22 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node);
/// @return Status
///
Status SetNextIteration(const ge::NodePtr &node, const std::string &next);

///
/// @brief Align the memory
/// @param [in/out] memory size
/// @param [in] alignment
/// @return void
///
void AlignMemSize(int64_t &mem_size, int64_t align_size);

///
/// @brief Get memory size from tensor desc
/// @param [in] node
/// @param [out] memory size
/// @return Status
///
Status GetMemorySize(const NodePtr &node, int64_t &output_size);
} // namespace ge

#endif // GE_GRAPH_COMMON_OMG_UTIL_H_

ge/graph/load/model_manager/davinci_model.cc (+7, -3)

@@ -60,6 +60,7 @@
#include "securec.h"
#include "graph/common/local_context.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "graph/common/omg_util.h"

// create std::thread, catch exceptions using try/catch
#define CREATE_STD_THREAD(thread_id, func, args) \
@@ -664,9 +665,12 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size
GELOGI("Logical stream index:%u, stream:%p, rtstream: %d.", i, stream, rt_stream_id);
}

for (uint32_t i = 0; i < EventNum(); i++) {
rtEvent_t rt_event;
GE_CHK_RT_RET(rtEventCreate(&rt_event));
uint32_t event_num = EventNum();
uint32_t create_flag = static_cast<uint32_t>((event_num > kEventReuseThreshold) ? RT_EVENT_WITH_FLAG :
RT_EVENT_DEFAULT);
for (uint32_t i = 0; i < event_num; ++i) {
rtEvent_t rt_event = nullptr;
GE_CHK_RT_RET(rtEventCreateWithFlag(&rt_event, create_flag));
event_list_.push_back(rt_event);
}



ge/graph/manager/graph_manager.cc (+7, -0)

@@ -95,6 +95,7 @@
#include "graph/passes/memcpy_addr_async_pass.h"
#include "graph/passes/hccl_continuous_memcpy_pass.h"
#include "graph/passes/parallel_group_pass.h"
#include "graph/passes/buffer_pool_memory_pass.h"
#include "graph/build/label_allocator.h"
#include "graph/utils/tensor_adapter.h"
#include "inc/pass_manager.h"
@@ -2528,6 +2529,12 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed.");
GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run.");

// Process offset and dependency for buffer pool memory assigner.
GE_TIMESTAMP_START(BufferPoolMemoryPass);
BufferPoolMemoryPass buffer_pool_mem_pass;
GE_CHK_STATUS_RET(buffer_pool_mem_pass.Run(compute_graph), "Failed to run buffer pool memory pass.");
GE_TIMESTAMP_END(BufferPoolMemoryPass, "BufferPoolMemoryPass::Run.");

// Handle parallel group .
GE_TIMESTAMP_START(ParallelGroup);
ParallelGroupPass parallel_group_pass;


ge/graph/passes/buffer_pool_memory_pass.cc (+574, -0)

@@ -0,0 +1,574 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/buffer_pool_memory_pass.h"

#include <string>
#include <vector>
#include "graph/common/omg_util.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "common/math/math_util.h"

namespace ge {
namespace {
const size_t kBufferPoolNodeInSize = 1;
const size_t kBufferPoolNodeOutSize = 1;
} // namespace

Status BufferPoolMemoryPass::Run(ComputeGraphPtr graph) {
if (graph == nullptr) {
GELOGE(PARAM_INVALID, "[Check][Graph]Graph is nullptr");
REPORT_INNER_ERROR("E19999", "Input graph is nullptr");
return PARAM_INVALID;
}
// The cache prefetching scheme is designed for very large models: it fetches the weight data in advance
// and places it in a dedicated memory pool. When such a model has dynamic shape, it needs to go through
// the executor flow and its memory is not allocated statically. That is a separate piece of work, so
// dynamic shape models are skipped here.
if (graph->GetParentGraph() != nullptr || graph->GetGraphUnknownFlag()) {
return SUCCESS;
}
if (!IsBufferPoolMemEnable(graph)) {
GELOGD("[Check][Enable]Buffer pool memory is not enable, graph:%s.", graph->GetName().c_str());
return SUCCESS;
}
Status ret = graph->TopologicalSorting();
if (ret != SUCCESS) {
GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
return ret;
}

ret = CopyOutForMultiUsedOutput(graph);
if (ret != SUCCESS) {
GELOGE(FAILED, "[Copy][Output]Graph:%s.", graph->GetName().c_str());
return FAILED;
}

ret = GetBufferPoolAndPeerCalcNodes(graph);
if (ret != SUCCESS) {
GELOGE(FAILED, "[Get][BufferPoolNode]Graph:%s.", graph->GetName().c_str());
return FAILED;
}
if (calc_nodes_.empty()) {
GELOGE(FAILED, "[Check][BufferPoolNode]Graph:%s.", graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "All Buffer pool nodes are isolated nodes in graph:%s.", graph->GetName().c_str());
return FAILED;
}
ret = AllocateAllBufferPoolSpace();
if (ret != SUCCESS) {
GELOGE(FAILED, "[Alloc][BufferPoolMem]Graph:%s.", graph->GetName().c_str());
return FAILED;
}

ret = SetResultOfMemoryAndEvent();
if (ret != SUCCESS) {
GELOGE(FAILED, "[Set][Result]Graph:%s.", graph->GetName().c_str());
return FAILED;
}
ret = graph->TopologicalSorting();
if (ret != SUCCESS) {
GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
return ret;
}
return SUCCESS;
}

void BufferPoolMemoryPass::ClearQueue(std::queue<std::pair<std::string, uint32_t>> &q) {
while (!q.empty()) {
q.pop();
}
}

Status BufferPoolMemoryPass::IsBufferPoolMemEnable(const ComputeGraphPtr &graph) {
for (NodePtr &node : graph->GetAllNodes()) {
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
if (op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) {
return true;
}
}
return false;
}

Status BufferPoolMemoryPass::CheckBufferPoolSize(int64_t total_size, int64_t pool_id, int64_t buffer_pool_size,
std::unordered_map<int64_t, int64_t> &calc_total_size) {
auto iter = calc_total_size.find(pool_id);
if (iter == calc_total_size.end()) {
calc_total_size[pool_id] = total_size;
} else {
FMK_INT64_ADDCHECK(calc_total_size[pool_id], total_size);
calc_total_size[pool_id] += total_size;
}
if (calc_total_size[pool_id] > buffer_pool_size) {
GELOGE(INTERNAL_ERROR, "[Check][Size]The memory required at the same is greater than buffer pool size, "
"pool id:%ld, pool size:%ld, required size:%ld.", pool_id, buffer_pool_size, calc_total_size[pool_id]);
REPORT_INNER_ERROR("E19999", "The memory required at the same is greater than buffer pool size, pool id:%ld,"
" pool size:%ld, required size:%ld.", pool_id, buffer_pool_size, calc_total_size[pool_id]);
return INTERNAL_ERROR;
}
return SUCCESS;
}

Status BufferPoolMemoryPass::TryToFixNodeOrder(NodePtr &pre_node, NodePtr &curr_node, bool &not_change) {
auto pre_node_graph = pre_node->GetOwnerComputeGraph();
auto curr_node_graph = curr_node->GetOwnerComputeGraph();
std::string pre_node_stream_label;
(void) AttrUtils::GetStr(pre_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, pre_node_stream_label);
std::string curr_node_stream_label;
(void) AttrUtils::GetStr(curr_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, curr_node_stream_label);
not_change = true;
if ((pre_node_graph == curr_node_graph) && (pre_node_stream_label == curr_node_stream_label)) {
// Same subgraph, including simultaneously in the root graph.
auto ret = ge::GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), curr_node->GetInControlAnchor());
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Add][Edge]Src:%s, dst:%s.", pre_node->GetName().c_str(), curr_node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to add ctrl edge from %s to %s.",
pre_node->GetName().c_str(), curr_node->GetName().c_str());
return INTERNAL_ERROR;
}
not_change = false;
} else if (pre_node_graph->GetParentGraph() == curr_node_graph->GetParentGraph() &&
pre_node_graph->GetParentNode() != nullptr && curr_node_graph->GetParentNode() != nullptr) {
// Two nodes are located on different child graphs of different parent nodes.
auto pre_node_parent_op_desc = pre_node_graph->GetParentNode()->GetOpDesc();
auto curr_node_parent_op_desc = curr_node_graph->GetParentNode()->GetOpDesc();
GE_CHECK_NOTNULL(pre_node_parent_op_desc);
GE_CHECK_NOTNULL(curr_node_parent_op_desc);
// If the dependency between the parent nodes is correct, the dependency between the child nodes is
// guaranteed, so there is no need to add control edges.
if (pre_node_parent_op_desc->GetId() > curr_node_parent_op_desc->GetId()) {
GELOGE(INTERNAL_ERROR, "[Check][Dependency]Invalid dependency, pre node:%s, curr node:%s.",
pre_node->GetName().c_str(), curr_node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Invalid dependency, pre node:%s, curr node:%s.",
pre_node->GetName().c_str(), curr_node->GetName().c_str());
return INTERNAL_ERROR;
}
GELOGI("[Check][Dependency]The two nodes are located in sub graphs of different parent nodes and meet the "
"dependency relationship. pre:%s, curr:%s.", pre_node->GetName().c_str(), curr_node->GetName().c_str());
} else {
GELOGE(INTERNAL_ERROR, "[Check][Dependency]Invalid dependency, pre node:%s, curr node:%s.",
pre_node->GetName().c_str(), curr_node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Invalid dependency, pre node:%s, curr node:%s.",
pre_node->GetName().c_str(), curr_node->GetName().c_str());
return INTERNAL_ERROR;
}
return SUCCESS;
}

Status BufferPoolMemoryPass::InsertMemCpyNodeAfter(ComputeGraphPtr &graph, NodePtr &node) {
auto out_anchor = node->GetOutDataAnchor(kBufferPoolNodeOutIndex);
OpDescBuilder op_desc_builder(node->GetName() + "_memcpy_async", MEMCPYASYNC);
auto mem_copy_op = op_desc_builder.AddInput("x", node->GetOpDesc()->GetOutputDesc(kBufferPoolNodeOutIndex))
.AddOutput("y", node->GetOpDesc()->GetOutputDesc(kBufferPoolNodeOutIndex))
.Build();
std::string stream_label;
bool get_attr = AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label);
if (get_attr && !stream_label.empty()) {
(void) AttrUtils::SetStr(mem_copy_op, ATTR_NAME_STREAM_LABEL, stream_label);
}
auto peer_in_anchors = out_anchor->GetPeerInDataAnchors();
std::vector<InDataAnchorPtr> in_anchors(peer_in_anchors.begin(), peer_in_anchors.end());
if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(mem_copy_op)) != GRAPH_SUCCESS) {
GELOGE(FAILED, "[Insert][Node] Node:%s.", node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to insert mem copy node after %s.", node->GetName().c_str());
return FAILED;
}
return SUCCESS;
}

Status BufferPoolMemoryPass::CopyOutForMultiUsedOutput(ComputeGraphPtr &graph) {
bool changed = false;
for (NodePtr &node : graph->GetAllNodes()) {
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
bool use_buffer_pool = op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE);
if (use_buffer_pool) {
if ((node->GetInDataNodes().size() == kBufferPoolNodeInSize) &&
(node->GetOutDataNodes().size() == kBufferPoolNodeOutSize)) {
continue;
} else if ((node->GetAllInDataAnchors().size() == kBufferPoolNodeInSize) &&
(node->GetAllOutDataAnchors().size() == kBufferPoolNodeOutSize)) {
// A prefetching output is used in multiple places. Copy one so that the prefetching node remains
// single input and single output.
if (InsertMemCpyNodeAfter(graph, node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Insert][MemCpy]Node:%s.", node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to insert mem copy node after %s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
changed = true;
GELOGI("[Insert][Node]Insert mem copy node after %s.", node->GetName().c_str());
} else {
GELOGE(PARAM_INVALID, "[Check][InputOutput]Only support single input and single output, "
"node:%s.", node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Only support single input and single output, node:%s.", node->GetName().c_str());
return PARAM_INVALID;
}
}
}
if (changed) {
Status ret = graph->TopologicalSorting();
if (ret != SUCCESS) {
GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
return ret;
}
}
return SUCCESS;
}

Status BufferPoolMemoryPass::GetBufferPoolAndPeerCalcNodes(const ComputeGraphPtr &graph) {
std::unordered_map<std::string, std::unordered_map<int64_t, std::set<NodePtr>>> unique_calc_nodes;
for (const NodePtr &node : graph->GetAllNodes()) {
auto in_data_nodes = node->GetInAllNodes();
for (NodePtr &in_node : in_data_nodes) {
int64_t buffer_pool_id = 0;
int64_t buffer_pool_size = 0;
bool get_attr = AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, buffer_pool_id);
get_attr = get_attr && (AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, buffer_pool_size));
if (get_attr) {
std::string batch_label;
(void) AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
peer_buffer_node_item_[batch_label][node].emplace_back(BufferPoolNodeItem(in_node, 0, 0));
buffer_node_to_calc_[batch_label][in_node] = node;
if (unique_calc_nodes[batch_label][buffer_pool_id].count(node) == 0) {
calc_nodes_[batch_label][buffer_pool_id].emplace_back(node);
unique_calc_nodes[batch_label][buffer_pool_id].insert(node);
}
GELOGI("[Get][BufferNode]Calc node:%s, pool node:%s.", node->GetName().c_str(), in_node->GetName().c_str());
Status ret = SetBufferPoolSize(batch_label, buffer_pool_id, buffer_pool_size);
if (ret != SUCCESS) {
GELOGE(ret, "[Set][BufferPoolSize]Node:%s", in_node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to set buffer pool size, something wrong with the info of node:%s",
in_node->GetName().c_str());
return ret;
}
}
}
}
return SUCCESS;
}

Status BufferPoolMemoryPass::SetBufferPoolSize(const std::string &batch_label, int64_t id, int64_t size) {
auto iter = buffer_pool_size_[batch_label].find(id);
if (iter != buffer_pool_size_[batch_label].end() && iter->second != size) {
GELOGE(PARAM_INVALID, "[Check][BufferPoolSize]Get different size with the same id, "
"id:%ld, original size:%ld, this size:%ld.", id, iter->second, size);
REPORT_INNER_ERROR("E19999", "Get different size with the same id, "
"id:%ld, original size:%ld, this size:%ld.", id, iter->second, size);
return PARAM_INVALID;
}
buffer_pool_size_[batch_label][id] = size;
return SUCCESS;
}

Status BufferPoolMemoryPass::AllocateAllBufferPoolSpace() {
for (const auto &iter : calc_nodes_) {
std::string batch_label = iter.first;
Status ret = AllocateSpaceInBatch(calc_nodes_[batch_label],
buffer_pool_size_[batch_label],
buffer_node_to_calc_[batch_label],
peer_buffer_node_item_[batch_label]);
if (ret != SUCCESS) {
GELOGE(ret, "[Alloc][InBatch]Batch_label:%s.", batch_label.c_str());
REPORT_INNER_ERROR("E19999", "Failed to allocate space in batch, batch_label:%s.", batch_label.c_str());
return ret;
}
GELOGI("[Alloc][InBatch]Alloc space in batch successfully, batch label:%s.", batch_label.c_str());
}
return SUCCESS;
}

Status BufferPoolMemoryPass::AllocateSpaceInBatch(
const std::map<int64_t, std::vector<NodePtr>> &calc_nodes,
const std::unordered_map<int64_t, int64_t> &buffer_pool_size_map,
const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc,
std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item) {
for (const auto &calc_node_in_pool : calc_nodes) {
int64_t pool_id = calc_node_in_pool.first;
int64_t buffer_pool_size = buffer_pool_size_map.at(pool_id);
ClearQueue(mem_ctrl_event_);
ClearQueue(stream_ctrl_event_);
BufferPool buffer_pool(pool_id, buffer_pool_size, buffer_node_to_calc);
Status ret = AllocateSpaceInBufferPool(buffer_pool,
calc_node_in_pool.second,
buffer_pool_nodes_item);
if (ret != SUCCESS) {
GELOGE(ret, "[Alloc][InBufferPool]Pool id:%ld, pool size:%ld.", pool_id, buffer_pool_size);
REPORT_INNER_ERROR("E19999", "Failed to allocate space in buffer pool, id:%ld, pool size:%ld.",
pool_id, buffer_pool_size);
return ret;
}
GELOGI("[Alloc][InBufferPool]Alloc space in buffer pool successfully, pool id:%ld.", pool_id);
}
return SUCCESS;
}

Status BufferPoolMemoryPass::AllocateSpaceInBufferPool(
const BufferPool &buffer_pool,
const std::vector<NodePtr> &calc_nodes_in_pool,
std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item) {
int64_t pool_id = buffer_pool.pool_id;
int64_t buffer_pool_size = buffer_pool.pool_size;
int64_t next_start = 0;
NodePtr pre_buffer_pool_node = nullptr;
std::queue<BufferPoolNodeItem> node_mem_range_in_pool;
node_mem_range_in_pool.push(BufferPoolMemoryPass::BufferPoolNodeItem(nullptr, 0, buffer_pool_size));
for (auto &calc_node : calc_nodes_in_pool) {
auto &peer_buffer_node_item = buffer_pool_nodes_item[calc_node];
std::unordered_map<int64_t, int64_t> calc_total_size;
size_t input_buffer_node_num = 0;
for (auto &node_item : peer_buffer_node_item) {
auto peer_buffer_node = node_item.node;
GE_CHECK_NOTNULL(peer_buffer_node);
int64_t total_size = 0;
++input_buffer_node_num;
Status ret = GetMemorySize(peer_buffer_node, total_size);
if (ret != SUCCESS) {
GELOGE(ret, "[Get][MemSize]Node:%s, calc_node:%s.",
peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to get memory size, node:%s, calc_node:%s.",
peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str());
return ret;
}
ret = CheckBufferPoolSize(total_size, pool_id, buffer_pool_size, calc_total_size);
if (ret != SUCCESS) {
GELOGE(ret, "[Check][BufferPoolSize]Capacity is not enough for all data, calc_node:%s.",
calc_node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Capacity is not enough for all data, calc_node:%s.",
calc_node->GetName().c_str());
return ret;
}
BufferPoolNodeItem buffer_pool_node_item(peer_buffer_node, calc_node, pre_buffer_pool_node, total_size,
0, 0, (input_buffer_node_num == peer_buffer_node_item.size()));
ret = AllocateSpaceForBufferPoolNode(next_start, buffer_pool, buffer_pool_node_item, node_mem_range_in_pool);
if (ret != SUCCESS) {
GELOGE(ret, "[Alloc][ForNode]Pool node:%s, calc_node:%s.",
peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to allocate space for buffer pool node:%s, calc_node:%s.",
peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str());
return ret;
}
pre_buffer_pool_node = peer_buffer_node;
}
}
return SUCCESS;
}

Status BufferPoolMemoryPass::AllocateSpaceForBufferPoolNode(int64_t &next_start,
const BufferPool buffer_pool,
BufferPoolNodeItem &buffer_pool_node_item,
std::queue<BufferPoolNodeItem> &node_mem_range_in_pool) {
// The event id must be obtained before FixTheTimingOfDependentNodes.
uint32_t logic_event = logic_event_num_;
NodePtr buffer_node = buffer_pool_node_item.node;
NodePtr calc_node = buffer_pool_node_item.out_calc_node;
/// When a calculation operator has multiple PREFETCH operators among its inputs, event insertion is
/// optimized so that an event is added only after the last PREFETCH operator.
/// w1 w2 w3 w4 w5
/// | | | | |
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 xxx
/// \ / \ / \ /
/// \ / \ / \ /
/// \ / \ / \ /
/// node1 node2 node3
/// | | |
/// | | |
/// --------------- other nodes ------------
///
/// The event id from the PREFETCH operator to the calculation operator must be generated before
/// FixTheTimingOfDependentNodes, because FixTheTimingOfDependentNodes may push a new id into stream_ctrl_event_,
/// and that id must not be reused until the next PREFETCH operator in the sequence.
if (buffer_pool_node_item.is_last_input) {
logic_event = GenerateEventId(buffer_node->GetName(), stream_ctrl_event_);
node_event_multiplexing_[buffer_node].push_back(string("SendTo;" + calc_node->GetName() +
";" + std::to_string(logic_event)));
mem_ctrl_event_.push(std::make_pair(calc_node->GetName(), logic_event));
}
NodePtr dependent_calc_node = GetOffsetAndDependency(next_start, buffer_pool_node_item.total_size,
buffer_pool.pool_size,
buffer_pool.buffer_node_to_calc,
node_mem_range_in_pool);
if (dependent_calc_node != nullptr) {
Status ret = FixTheTimingOfDependentNodes(dependent_calc_node, buffer_node);
if (ret != SUCCESS) {
GELOGE(ret, "[Fix][Timing]Pool_id:%ld, pool node:%s, dependent node:%s.",
buffer_pool.pool_id, buffer_node->GetName().c_str(), dependent_calc_node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to fix timing, pool_id:%ld, pool node:%s, dependent node:%s.",
buffer_pool.pool_id, buffer_node->GetName().c_str(),
dependent_calc_node->GetName().c_str());
return ret;
}
}

buffer_pool_node_item.offset_start = next_start;
buffer_node_logical_offset_[buffer_node].push_back(buffer_pool_node_item.total_size);
buffer_node_logical_offset_[buffer_node].push_back(next_start);
FMK_INT64_ADDCHECK(next_start, buffer_pool_node_item.total_size);
next_start += buffer_pool_node_item.total_size;
buffer_pool_node_item.offset_end = next_start;
node_mem_range_in_pool.push(buffer_pool_node_item);
if (buffer_pool_node_item.pre_buffer_pool_node != nullptr) {
bool not_change = true;
auto ret = TryToFixNodeOrder(buffer_pool_node_item.pre_buffer_pool_node, buffer_node, not_change);
if (ret != SUCCESS) {
GELOGE(ret, "[Fix][BufferPoolNodeOrder]Pre node:%s, curr node:%s.",
buffer_pool_node_item.pre_buffer_pool_node->GetName().c_str(), buffer_node->GetName().c_str());
return ret;
}
}
GELOGI("[Alloc][ForNode]Buffer pool node %s send to %s, offset start:%ld, send event id:%u.",
buffer_node->GetName().c_str(), calc_node->GetName().c_str(),
buffer_pool_node_item.offset_start, logic_event);
return SUCCESS;
}
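
// Illustrative result for the multi-PREFETCH diagram in the comment above (a sketch, assuming
// event ids start at 0 and no reuse happens in between): only the last PREFETCH input of each
// calculation node records a send event, so node_event_multiplexing_ ends up holding roughly
//   prefetch2 -> {"SendTo;node1;0"}
//   prefetch4 -> {"SendTo;node2;1"}
//   prefetch5 -> {"SendTo;node3;2"}
// while prefetch1 and prefetch3 record nothing of their own; any RecvFrom entries added later by
// FixTheTimingOfDependentNodes are omitted here.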

/// When generating an event id, check whether the node name at the head of the queue matches the
/// current operator's name, in order to handle scenarios such as the following:
/// w1 w2 w3 w4 w5
/// | | | | |
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// | | | | |
/// node1 node2 node3 node4 node5
///
/// Memory distribution:
///
/// |____w1_____|__|
///
/// |____w2_____|__|
///
/// |____w3_____|__|
///
/// |______w4______|
///
/// |______w5______|
///
/// In this scenario, prefetch2 depends on node1. If the dependency is handled by adding an event from node1 to
/// prefetch2, the id sent by prefetch2 could be the same as the id it receives. Although Runtime supports this
/// through WaitReset, we consider it a risky operation and avoid it.
uint32_t BufferPoolMemoryPass::GenerateEventId(const std::string &node_name,
std::queue<std::pair<std::string, uint32_t>> &event_queue) {
uint32_t logic_event = logic_event_num_;
if (!event_queue.empty()) {
auto item = event_queue.front();
if (item.first != node_name) {
logic_event = item.second;
event_queue.pop();
return logic_event;
}
}
++logic_event_num_;
return logic_event;
}
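
// A minimal standalone sketch of the reuse rule above (an illustration, not part of this pass;
// NextEventId and its parameters are hypothetical names):
//
//   #include <cstdint>
//   #include <queue>
//   #include <string>
//   #include <utility>
//
//   // Reuse the id at the head of the queue unless it was recorded for node_name itself, in
//   // which case allocate a fresh id so a node never sends on the event it waits for.
//   uint32_t NextEventId(const std::string &node_name, uint32_t &next_new_id,
//                        std::queue<std::pair<std::string, uint32_t>> &reusable) {
//     if (!reusable.empty() && reusable.front().first != node_name) {
//       uint32_t id = reusable.front().second;
//       reusable.pop();
//       return id;
//     }
//     return next_new_id++;
//   }
//
// With reusable = {{"prefetch2", 3}} and next_new_id = 5, a call for "prefetch2" returns the
// fresh id 5 (avoiding send id == recv id on event 3), while a call for "prefetch4" would
// reuse id 3.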

NodePtr BufferPoolMemoryPass::GetOffsetAndDependency(int64_t &next_start,
int64_t total_mem_size,
int64_t buffer_pool_size,
const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc,
std::queue<BufferPoolMemoryPass::BufferPoolNodeItem> &nodes_in_buffer) {
// The buffer pool can no longer fit this tensor, so the allocation wraps back to the start.
if (next_start + total_mem_size > buffer_pool_size) {
next_start = 0;
if (!nodes_in_buffer.empty()) {
// Let the last node take up the rest of the space at the end of the pool.
nodes_in_buffer.back().offset_end = buffer_pool_size;
// Pop the first tensor memory of the previous round.
nodes_in_buffer.pop();
}
while (!nodes_in_buffer.empty()) {
auto node_item = nodes_in_buffer.front();
// Stop at the beginning of the previous round.
if (node_item.offset_start == 0) {
break;
}
nodes_in_buffer.pop();
}
}

while (!nodes_in_buffer.empty()) {
auto node_item = nodes_in_buffer.front();
if (next_start + total_mem_size <= node_item.offset_end) {
auto pool_node = node_item.node;
if (pool_node == nullptr) {
return nullptr;
}
auto output_calc = buffer_node_to_calc.find(pool_node);
if (output_calc != buffer_node_to_calc.end()) {
return output_calc->second;
}
return nullptr;
}
nodes_in_buffer.pop();
}
return nullptr;
}
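
// Worked example of the wrap-around above (a sketch; pool size 2048 and tensor sizes 1000,
// 1000 and 500 are assumed values):
//   - tensor 1: next_start = 0,    0 + 1000 <= 2048    -> range [0, 1000),    no dependency
//   - tensor 2: next_start = 1000, 1000 + 1000 <= 2048 -> range [1000, 2000), no dependency
//   - tensor 3: next_start = 2000, 2000 + 500 > 2048   -> wrap: tensor 2 is stretched to the
//     pool end, next_start resets to 0 and tensor 3 reuses [0, 500); it therefore depends on
//     the calculation node that consumed tensor 1, and the caller fixes that timing with
//     FixTheTimingOfDependentNodes.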

Status BufferPoolMemoryPass::FixTheTimingOfDependentNodes(NodePtr &dependent_calc_node, NodePtr &curr_pool_node) {
// The preceding steps ensure that none of these pointers is null.
bool not_change = false;
Status ret = TryToFixNodeOrder(dependent_calc_node, curr_pool_node, not_change);
if (ret != SUCCESS) {
GELOGE(ret, "[Fix][NodeOrder]Src:%s, dst:%s.",
dependent_calc_node->GetName().c_str(), curr_pool_node->GetName().c_str());
return ret;
}
if (not_change) {
return SUCCESS;
}
uint32_t logic_event = GenerateEventId(dependent_calc_node->GetName(), mem_ctrl_event_);
node_event_multiplexing_[curr_pool_node].push_back(string("RecvFrom;" + dependent_calc_node->GetName() +
";" + std::to_string(logic_event)));
stream_ctrl_event_.push(std::make_pair(curr_pool_node->GetName(), logic_event));
GELOGI("[Fix][Timing]Add ctrl edge for buffer pool memory from %s to %s, buffer pool node recv event:%u.",
dependent_calc_node->GetName().c_str(), curr_pool_node->GetName().c_str(), logic_event);
return SUCCESS;
}

Status BufferPoolMemoryPass::SetResultOfMemoryAndEvent() {
for (auto &iter : node_event_multiplexing_) {
auto node = iter.first;
GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
bool ret = AttrUtils::SetListStr(op_desc, ATTR_NAME_EVENT_MULTIPLEXING, iter.second);
if (!ret) {
GELOGE(INTERNAL_ERROR, "[Set][Attr]Node:%s.", node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to set event reuse info, node:%s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
auto offset_iter = buffer_node_logical_offset_.find(node);
if (offset_iter == buffer_node_logical_offset_.end()) {
GELOGE(INTERNAL_ERROR, "[Get][LogicalOffset]Node:%s.", node->GetName().c_str());
REPORT_INNER_ERROR("E19999", "Failed to get logical offset and size, node:%s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
ret = AttrUtils::SetListInt(op_desc, ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET, offset_iter->second);
if (!ret) {
GELOGE(INTERNAL_ERROR, "[Set][Attr]Node:%s.", node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to set node memory offset and size, node:%s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
}
return SUCCESS;
}
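
// Format note and a minimal parsing sketch (an illustration outside this pass; SplitEventItem is
// a hypothetical helper): every entry written above has the form "SendTo;<calc_node>;<event_id>"
// or "RecvFrom;<calc_node>;<event_id>", and ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET stores
// [total_size, logical_offset] pairs, size first and then the logical offset.
//
//   #include <sstream>
//   #include <string>
//   #include <vector>
//
//   // Splits "SendTo;add1;0" into {"SendTo", "add1", "0"}; the consumer is still expected to
//   // validate the field count, the keyword and the event id range.
//   std::vector<std::string> SplitEventItem(const std::string &item) {
//     std::vector<std::string> fields;
//     std::stringstream ss(item);
//     std::string field;
//     while (std::getline(ss, field, ';')) {
//       fields.push_back(field);
//     }
//     return fields;
//   }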
} // namespace ge

+ 136
- 0
ge/graph/passes/buffer_pool_memory_pass.h View File

@@ -0,0 +1,136 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_
#define GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_

#include <queue>
#include "graph/graph.h"
#include "inc/graph_pass.h"

namespace ge {
class BufferPoolMemoryPass : public GraphPass {
public:
explicit BufferPoolMemoryPass() : logic_event_num_(0) {}

~BufferPoolMemoryPass() override = default;

struct BufferPool {
int64_t pool_id = 0;
int64_t pool_size = 0;
std::unordered_map<NodePtr, NodePtr> buffer_node_to_calc;
BufferPool(int64_t id, int64_t size, const std::unordered_map<NodePtr, NodePtr> &node_map)
: pool_id(id), pool_size(size), buffer_node_to_calc(node_map) {}
};

struct BufferPoolNodeItem {
NodePtr node = nullptr;
NodePtr out_calc_node = nullptr;
NodePtr pre_buffer_pool_node = nullptr;
int64_t total_size = 0;
int64_t offset_start = 0;
int64_t offset_end = 0;
bool is_last_input = true;
BufferPoolNodeItem(const NodePtr &buffer_n, const NodePtr &calc_n, const NodePtr &pre_buffer_n,
int64_t size, int64_t start, int64_t end, bool last)
: node(std::move(buffer_n)),
out_calc_node(std::move(calc_n)),
pre_buffer_pool_node(std::move(pre_buffer_n)),
total_size(size),
offset_start(start),
offset_end(end),
is_last_input(last) {}

BufferPoolNodeItem(const NodePtr &buffer_n, int64_t start, int64_t end)
: node(std::move(buffer_n)),
out_calc_node(nullptr),
pre_buffer_pool_node(nullptr),
total_size(0),
offset_start(start),
offset_end(end),
is_last_input(true) {}
};

Status Run(ComputeGraphPtr graph) override;

private:
static void ClearQueue(std::queue<std::pair<std::string, uint32_t>> &q);

static Status IsBufferPoolMemEnable(const ComputeGraphPtr &graph);

static Status CheckBufferPoolSize(int64_t total_size, int64_t pool_id, int64_t buffer_pool_size,
std::unordered_map<int64_t, int64_t> &calc_total_size);

static Status TryToFixNodeOrder(NodePtr &pre_node, NodePtr &curr_node, bool &not_change);

Status InsertMemCpyNodeAfter(ComputeGraphPtr &graph, NodePtr &node);

Status CopyOutForMultiUsedOutput(ComputeGraphPtr &graph);

Status GetBufferPoolAndPeerCalcNodes(const ComputeGraphPtr &graph);

Status SetBufferPoolSize(const std::string &batch_label, int64_t id, int64_t size);

Status AllocateAllBufferPoolSpace();

Status AllocateSpaceInBatch(const std::map<int64_t, std::vector<NodePtr>> &calc_nodes,
const std::unordered_map<int64_t, int64_t> &buffer_pool_size_map,
const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc,
std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item);

Status AllocateSpaceInBufferPool(const BufferPool &buffer_pool,
const std::vector<NodePtr> &calc_nodes_in_pool,
std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item);

Status AllocateSpaceForBufferPoolNode(int64_t &next_start,
const BufferPool buffer_pool,
BufferPoolNodeItem &buffer_pool_node_item,
std::queue<BufferPoolNodeItem> &node_mem_range_in_pool);

NodePtr GetOffsetAndDependency(int64_t &next_start,
int64_t total_mem_size,
int64_t buffer_pool_size,
const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc,
std::queue<BufferPoolNodeItem> &nodes_in_buffer);

Status FixTheTimingOfDependentNodes(NodePtr &dependent_calc_node, NodePtr &curr_pool_node);

uint32_t GenerateEventId(const std::string &node_name, std::queue<std::pair<std::string, uint32_t>> &event_queue);

Status SetResultOfMemoryAndEvent();

// Use std::map so that each traversal follows the order of batch label and pool id.
std::map<std::string, std::map<int64_t, std::vector<NodePtr>>> calc_nodes_;

std::unordered_map<std::string, std::unordered_map<NodePtr, NodePtr>> buffer_node_to_calc_;

std::unordered_map<std::string, std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>>> peer_buffer_node_item_;

std::unordered_map<std::string, std::unordered_map<int64_t, int64_t>> buffer_pool_size_;

uint32_t logic_event_num_;

std::queue<std::pair<std::string, uint32_t>> mem_ctrl_event_;

std::queue<std::pair<std::string, uint32_t>> stream_ctrl_event_;

std::unordered_map<NodePtr, std::vector<std::string>> node_event_multiplexing_;

std::unordered_map<NodePtr, std::vector<int64_t>> buffer_node_logical_offset_;
};
} // namespace ge

#endif // GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_
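
// Minimal usage sketch (mirrors the unit tests added in this change; "graph" is assumed to be a
// ge::ComputeGraphPtr built elsewhere):
//   ge::BufferPoolMemoryPass pass;
//   ge::Status ret = pass.Run(graph);
// On success the pass records "_event_multiplexing" and the buffer-pool size/offset attributes on
// each prefetch node; BufferPoolMemAssigner and StreamAllocator consume them later.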

+ 5
- 0
tests/depends/runtime/src/runtime_stub.cc View File

@@ -43,6 +43,11 @@ rtError_t rtEventCreate(rtEvent_t *event) {
*event = new int[EVENT_LENTH];
return RT_ERROR_NONE;
}

rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag) {
return rtEventCreate(event);
}

rtError_t rtEventRecord(rtEvent_t event, rtStream_t stream) { return RT_ERROR_NONE; }

rtError_t rtEventSynchronize(rtEvent_t event) { return RT_ERROR_NONE; }


+ 5
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -276,6 +276,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc"
"${GE_CODE_DIR}/ge/model/ge_model.cc"
"${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc"
@@ -323,6 +324,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/build/memory/block_mem_assigner.cc"
"${GE_CODE_DIR}/ge/graph/build/memory/binary_block_mem_assigner.cc"
"${GE_CODE_DIR}/ge/graph/build/memory/max_block_mem_assigner.cc"
"${GE_CODE_DIR}/ge/graph/build/memory/buffer_pool_mem_assigner.cc"
"${GE_CODE_DIR}/ge/graph/manager/graph_mem_allocator.cc"
"${GE_CODE_DIR}/ge/graph/manager/graph_var_manager.cc"
"${GE_CODE_DIR}/ge/analyzer/analyzer.cc"
@@ -627,6 +629,7 @@ set(SINGLE_OP_SRC_FILES
# test files
set(COMMON_TEST_FILES
"graph/passes/graph_builder_utils.cc"
"graph/utils/buffer_pool_graph_builder.cc"
"test.cc"
)

@@ -703,6 +706,7 @@ set(PASS_TEST_FILES
"graph/passes/link_gen_mask_nodes_pass_unittest.cc"
"graph/passes/transpose_transdata_pass_unittest.cc"
"graph/passes/parallel_group_pass_unittest.cc"
"graph/passes/buffer_pool_memory_pass_unittest.cc"
)

set(KERNEL_TEST_FILES
@@ -771,6 +775,7 @@ set(MULTI_PARTS_TEST_FILES
"graph/build/model_builder_unittest.cc"
"graph/build/mem_assigner_unittest.cc"
"graph/build/task_generator_unittest.cc"
"graph/build/buffer_pool_mem_assigner_unittest.cc"
"graph/preprocess/graph_preprocess_unittest.cc"
"graph/manager/hcom_util_unittest.cc"
"graph/manager/graph_caching_allocator_unittest.cc"


+ 607
- 0
tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc View File

@@ -0,0 +1,607 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "../utils/buffer_pool_graph_builder.h"
#include "graph/passes/buffer_pool_memory_pass.h"

#define protected public
#define private public
#include "graph/build/memory/buffer_pool_mem_assigner.h"
#include "graph/build/memory/graph_mem_assigner.h"
#include "graph/build/stream_allocator.h"
#undef protected
#undef private

namespace ge {
namespace {
const int64_t kMemoryTypeHBM = static_cast<int64_t>(RT_MEMORY_HBM);
const int64_t kMemoryTypeP2P = static_cast<int64_t>(RT_MEMORY_P2P_HBM);
const int64_t kMemoryTypeDDR = static_cast<int64_t>(RT_MEMORY_DDR);
const size_t kOffsetHBM = 10240;
const size_t kOffsetP2P = 20480;
const size_t kOffsetDDR = 30720;
const int64_t kMemAlignSize = 512;

int64_t AlignMemSize(int64_t mem_size, int64_t align_size = kMemAlignSize) {
int64_t tmp = (mem_size + align_size - 1) / align_size * align_size;
return tmp;
}
int64_t AlignOutputMemSize(int64_t mem_size) {
int64_t tmp = (mem_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize;
// hccl needs extra alignment padding
tmp = kMemAlignSize + tmp + kMemAlignSize;
return tmp;
}
} // namespace
class UtestBufferPoolMemAssignerTest : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}

};

TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_normal_assign_success) {
ut::BufferPoolGraphBuilder builder("NormalGraph");
ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);
std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
{kMemoryTypeP2P, kOffsetP2P}};
int64_t offset_base = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
std::vector<int64_t> expect_offset = {(offset_base + 0),
(offset_base + AlignOutputMemSize(500)),
(offset_base + (AlignOutputMemSize(500) * 2)),
(offset_base + 0),
(offset_base + AlignOutputMemSize(1024))};

BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
ret = buffer_pool_mem_assigner.Assign();
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base +
AlignMemSize(5600, kMemAlignSize) + kMemAlignSize);

{
auto prefetch = graph->FindNode("prefetch1");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(0));
}

{
auto prefetch = graph->FindNode("prefetch2");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(1));
}

{
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(2));
}

{
auto prefetch = graph->FindNode("prefetch4");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(3));
}

{
auto prefetch = graph->FindNode("prefetch5");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(4));
}
}

TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_normal_graph_with_multi_buffer_pool_assign_success) {
ut::BufferPoolGraphBuilder builder("NormalGraphWithMultiBufferPool");
ge::ComputeGraphPtr graph = builder.BuildNormalGraphWithMultiBufferPool();
BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);
std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
{kMemoryTypeP2P, kOffsetP2P}};
int64_t offset_base_0 = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
int64_t offset_base_1 = static_cast<int64_t>(kOffsetHBM + kMemAlignSize) +
AlignMemSize(5000, kMemAlignSize) + kMemAlignSize;
std::vector<int64_t> expect_offset = {(offset_base_0 + 0),
(offset_base_1 + 0),
(offset_base_0 + AlignOutputMemSize(500)),
(offset_base_0 + 0),
(offset_base_1 + AlignOutputMemSize(500))};

BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
ret = buffer_pool_mem_assigner.Assign();
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base_1 +
AlignMemSize(5000, kMemAlignSize) + kMemAlignSize);

{
auto prefetch = graph->FindNode("prefetch1");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(0));
}

{
auto prefetch = graph->FindNode("prefetch2");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(1));
}

{
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(2));
}

{
auto prefetch = graph->FindNode("prefetch4");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(3));
}

{
auto prefetch = graph->FindNode("prefetch5");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(4));
}
}

TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_serial_graph_assign_success) {
ut::BufferPoolGraphBuilder builder("SerialGraph");
ge::ComputeGraphPtr graph = builder.BuildSerialGraph();
BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);
std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
{kMemoryTypeP2P, kOffsetP2P}};
int64_t offset_base = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
std::vector<int64_t> expect_offset = {offset_base, offset_base, offset_base, offset_base, offset_base};

BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
ret = buffer_pool_mem_assigner.Assign();
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base +
AlignMemSize(2048, kMemAlignSize) + kMemAlignSize);

{
auto prefetch = graph->FindNode("prefetch1");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(0));
}

{
auto prefetch = graph->FindNode("prefetch2");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(1));
}

{
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(2));
}

{
auto prefetch = graph->FindNode("prefetch4");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(3));
}

{
auto prefetch = graph->FindNode("prefetch5");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(4));
}
}

TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_subgraph_with_inner_dependency_assign_success) {
ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency");
ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency();
BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);
std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
{kMemoryTypeP2P, kOffsetP2P}};
int64_t offset_base = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
std::vector<int64_t> expect_offset = {(offset_base + 0),
(offset_base + AlignOutputMemSize(500)),
(offset_base + (AlignOutputMemSize(500) * 2)),
(offset_base + 0),
(offset_base + AlignOutputMemSize(1024))};

BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
ret = buffer_pool_mem_assigner.Assign();
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base +
AlignMemSize(5600, kMemAlignSize) + kMemAlignSize);

std::map<std::string, NodePtr> all_nodes;
for (auto node : graph->GetAllNodes()) {
EXPECT_NE(node, nullptr);
all_nodes[node->GetName()] = node;
}

{
auto prefetch = all_nodes.at("prefetch1");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(0));
}

{
auto prefetch = all_nodes.at("prefetch2");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(1));
}

{
auto prefetch = all_nodes.at("prefetch3");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(2));
}

{
auto prefetch = all_nodes.at("prefetch4");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(3));
}

{
auto prefetch = all_nodes.at("prefetch5");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(4));
}
}

TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_graph_with_multi_batch_assign_success) {
ut::BufferPoolGraphBuilder builder("GraphWithMultiBatch");
ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiBatch();
BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);
std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
{kMemoryTypeP2P, kOffsetP2P}};
int64_t offset_base = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
std::vector<int64_t> expect_offset = {(offset_base + 0),
(offset_base + AlignOutputMemSize(500)),
(offset_base + (AlignOutputMemSize(500) * 2)),
(offset_base + 0),
(offset_base + AlignOutputMemSize(1024))};

BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
ret = buffer_pool_mem_assigner.Assign();
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base +
AlignMemSize(5600, kMemAlignSize) + kMemAlignSize);

{
auto prefetch = graph->FindNode("batch_label_128/prefetch1");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(0));
}

{
auto prefetch = graph->FindNode("batch_label_128/prefetch2");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(1));
}

{
auto prefetch = graph->FindNode("batch_label_128/prefetch3");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(2));
}

{
auto prefetch = graph->FindNode("batch_label_128/prefetch4");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(3));
}

{
auto prefetch = graph->FindNode("batch_label_128/prefetch5");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(4));
}

{
auto prefetch = graph->FindNode("batch_label_256/prefetch1");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(0));
}

{
auto prefetch = graph->FindNode("batch_label_256/prefetch2");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(1));
}

{
auto prefetch = graph->FindNode("batch_label_256/prefetch3");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(2));
}

{
auto prefetch = graph->FindNode("batch_label_256/prefetch4");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(3));
}

{
auto prefetch = graph->FindNode("batch_label_256/prefetch5");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
EXPECT_EQ(output_offset.size(), 1);
EXPECT_EQ(output_offset.at(0), expect_offset.at(4));
}
}

TEST_F(UtestBufferPoolMemAssignerTest, test_AssignBufferPoolMemory_success) {
ut::BufferPoolGraphBuilder builder("NormalGraph");
ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);
std::map<int64_t, MemoryOffset> memory_offset = {{kMemoryTypeHBM, MemoryOffset(RT_MEMORY_HBM, kOffsetHBM)},
{kMemoryTypeP2P, MemoryOffset(RT_MEMORY_P2P_HBM, kOffsetP2P)}};

GraphMemoryAssigner graph_memory_assigner(graph);
graph_memory_assigner.memory_offset_ = memory_offset;
ret = graph_memory_assigner.AssignBufferPoolMemory();
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestBufferPoolMemAssignerTest, test_AssignBufferPoolMemory_fail) {
ut::BufferPoolGraphBuilder builder("NormalGraph");
ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
std::map<int64_t, MemoryOffset> memory_offset = {{kMemoryTypeHBM, MemoryOffset(RT_MEMORY_HBM, kOffsetHBM)},
{kMemoryTypeP2P, MemoryOffset(RT_MEMORY_P2P_HBM, kOffsetP2P)}};
{
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
std::vector<int64_t> type_list = {static_cast<int64_t>(RT_MEMORY_P2P_HBM)};
bool set_attr = ge::AttrUtils::SetListInt(prefetch->GetOpDesc(), ATTR_NAME_OUTPUT_MEM_TYPE_LIST, type_list);
EXPECT_EQ(set_attr, true);

GraphMemoryAssigner graph_memory_assigner(graph);
graph_memory_assigner.memory_offset_ = memory_offset;
Status ret = graph_memory_assigner.AssignBufferPoolMemory();
EXPECT_EQ(ret, FAILED);
}

{
std::vector<std::string> node_list = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"};
std::vector<int64_t> type_list = {static_cast<int64_t>(RT_MEMORY_L1)};
for (auto &node_name : node_list) {
auto prefetch = graph->FindNode(node_name);
EXPECT_NE(prefetch, nullptr);
EXPECT_NE(prefetch->GetOpDesc(), nullptr);
bool set_attr = ge::AttrUtils::SetListInt(prefetch->GetOpDesc(), ATTR_NAME_OUTPUT_MEM_TYPE_LIST, type_list);
EXPECT_EQ(set_attr, true);
}
GraphMemoryAssigner graph_memory_assigner(graph);
graph_memory_assigner.memory_offset_ = memory_offset;
Status ret = graph_memory_assigner.AssignBufferPoolMemory();
EXPECT_EQ(ret, FAILED);
}
}

TEST_F(UtestBufferPoolMemAssignerTest, test_RefreshEventsWithReuse_success) {
ut::BufferPoolGraphBuilder builder("NormalGraph");
ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

std::map<std::string, NodePtr> all_nodes;
for (auto node : graph->GetAllNodes()) {
EXPECT_NE(node, nullptr);
all_nodes[node->GetName()] = node;
}

Graph2SubGraphInfoList sub_graphs;
StreamAllocator stream_allocator(graph, sub_graphs);
stream_allocator.event_num_ = 65520;

// stream ctrl event
stream_allocator.AddSendEventId(all_nodes.at("prefetch1"), 30);
stream_allocator.AddRecvEventId(all_nodes.at("add1"), 30);

stream_allocator.AddSendEventId(all_nodes.at("prefetch2"), 31);
stream_allocator.AddRecvEventId(all_nodes.at("add2"), 31);

stream_allocator.AddSendEventId(all_nodes.at("prefetch3"), 32);
stream_allocator.AddRecvEventId(all_nodes.at("add3"), 32);

stream_allocator.AddSendEventId(all_nodes.at("prefetch4"), 33);
stream_allocator.AddRecvEventId(all_nodes.at("add4"), 33);

stream_allocator.AddSendEventId(all_nodes.at("add2"), 34);
stream_allocator.AddRecvEventId(all_nodes.at("prefetch4"), 34);

stream_allocator.AddSendEventId(all_nodes.at("prefetch5"), 35);
stream_allocator.AddRecvEventId(all_nodes.at("add5"), 35);

stream_allocator.AddSendEventId(all_nodes.at("add3"), 36);
stream_allocator.AddRecvEventId(all_nodes.at("prefetch5"), 36);

// other event
stream_allocator.AddSendEventId(all_nodes.at("prefetch1"), 37);
stream_allocator.AddRecvEventId(all_nodes.at("add5"), 37);


ret = stream_allocator.RefreshEventsWithReuse();
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ((stream_allocator.node_to_send_events_.at(all_nodes.at("prefetch1"))).size(), 2);
EXPECT_EQ((stream_allocator.node_to_send_events_.at(all_nodes.at("prefetch5"))).size(), 1);
EXPECT_EQ((stream_allocator.node_to_recv_events_.at(all_nodes.at("prefetch5"))).size(), 1);
EXPECT_EQ((stream_allocator.node_to_recv_events_.at(all_nodes.at("add5"))).size(), 2);
EXPECT_EQ(stream_allocator.event_num_, 5);
}

TEST_F(UtestBufferPoolMemAssignerTest, test_RefreshEventsWithReuse_fail) {
ut::BufferPoolGraphBuilder builder("NormalGraph");
ge::ComputeGraphPtr graph = builder.BuildNormalGraph();

std::map<std::string, NodePtr> all_nodes;
for (auto node : graph->GetAllNodes()) {
EXPECT_NE(node, nullptr);
all_nodes[node->GetName()] = node;
}
std::vector<std::vector<std::string>> event_info = {{"SendTo;add1;0"},
{"SendTo;add2;1"},
{"SendTo;add3;2"},
{"SendTo;add4;3", "RecvFrom;add2;0"},
{"SendTo;add5;0", "RecvFrom;add3;1"}};

(void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]);
(void) AttrUtils::SetListStr(all_nodes.at("prefetch2")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[1]);
(void) AttrUtils::SetListStr(all_nodes.at("prefetch3")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[2]);
(void) AttrUtils::SetListStr(all_nodes.at("prefetch4")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[3]);
(void) AttrUtils::SetListStr(all_nodes.at("prefetch5")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[4]);

Graph2SubGraphInfoList sub_graphs;
StreamAllocator stream_allocator(graph, sub_graphs);
stream_allocator.event_num_ = 65520;

// Item num of raw event info is invalid
event_info[0][0] = "SendTo;add1;0;1";
(void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]);
Status ret = stream_allocator.RefreshEventsWithReuse();
EXPECT_EQ(ret, PARAM_INVALID);

// Event id is an invalid argument
event_info[0][0] = "SendTo;add1;event_id";
(void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]);
ret = stream_allocator.RefreshEventsWithReuse();
EXPECT_EQ(ret, PARAM_INVALID);

// Event id is out of range
event_info[0][0] = "SendTo;add1;666666666666666666666666666666666666666";
(void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]);
ret = stream_allocator.RefreshEventsWithReuse();
EXPECT_EQ(ret, PARAM_INVALID);

// Event id is negative
event_info[0][0] = "SendTo;add1;-2";
(void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]);
ret = stream_allocator.RefreshEventsWithReuse();
EXPECT_EQ(ret, PARAM_INVALID);

// Keyword is not supported
event_info[0][0] = "SendToKey;add1;2";
(void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]);
ret = stream_allocator.RefreshEventsWithReuse();
EXPECT_EQ(ret, PARAM_INVALID);
}
} // namespace ge


+ 591
- 0
tests/ut/ge/graph/passes/buffer_pool_memory_pass_unittest.cc View File

@@ -0,0 +1,591 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "inc/pass_manager.h"
#include "graph_builder_utils.h"
#include "../utils/buffer_pool_graph_builder.h"
#include "graph/passes/buffer_pool_memory_pass.h"

namespace ge {
class UtestBufferPoolMemoryPass : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_normal_success_test) {
ut::BufferPoolGraphBuilder builder("NormalGraph");
ge::ComputeGraphPtr graph = builder.BuildNormalGraph();

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch1");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add1;0");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch2");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add2;1");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add3;2");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch4");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add4;3");
EXPECT_EQ(event_info.at(1), "RecvFrom;add2;0");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add2");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch5");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add5;0");
EXPECT_EQ(event_info.at(1), "RecvFrom;add3;1");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add3");
}
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_normal_graph_with_multi_buffer_pool_success_test) {
ut::BufferPoolGraphBuilder builder("NormalGraphWithMultiBufferPool");
ge::ComputeGraphPtr graph = builder.BuildNormalGraphWithMultiBufferPool();

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch1");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add1;0");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch2");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add2;3");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add3;1");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch4");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add4;2");
EXPECT_EQ(event_info.at(1), "RecvFrom;add3;0");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add3");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch5");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add5;4");
}
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_contain_one_node_success_test) {
ut::BufferPoolGraphBuilder builder("SerialGraph");
ge::ComputeGraphPtr graph = builder.BuildSerialGraph();

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch1");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add1;0");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch2");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add2;1");
EXPECT_EQ(event_info.at(1), "RecvFrom;add1;2");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add1");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add3;2");
EXPECT_EQ(event_info.at(1), "RecvFrom;add2;0");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add2");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch4");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add4;0");
EXPECT_EQ(event_info.at(1), "RecvFrom;add3;1");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add3");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch5");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add5;1");
EXPECT_EQ(event_info.at(1), "RecvFrom;add4;2");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add4");
}
}

TEST_F(UtestBufferPoolMemoryPass, calc_node_with_multi_buffer_pool_input_success_test) {
ut::BufferPoolGraphBuilder builder("GraphWithMultiPrefetch");
ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiPrefetch();

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch1");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 0);
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch2");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add1;0");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 0);
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch4");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add2;1");
EXPECT_EQ(event_info.at(1), "RecvFrom;add1;2");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add1");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch5");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add3;2");
EXPECT_EQ(event_info.at(1), "RecvFrom;add2;0");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add2");
}
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_in_different_subgraph_success_test) {
ut::BufferPoolGraphBuilder builder("GraphWithSubgraph");
ge::ComputeGraphPtr graph = builder.BuildGraphWithSubgraph();

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

std::map<std::string, NodePtr> all_nodes;
for (auto node : graph->GetAllNodes()) {
EXPECT_NE(node, nullptr);
all_nodes[node->GetName()] = node;
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch1");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add1;0");
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch2");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add2;1");
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add3;2");
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch4");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add4;3");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 0);
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch5");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add5;4");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 1);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "prefetch4");
}
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_in_different_subgraph_with_inner_dependency_success_test) {
ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency");
ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency();

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

std::map<std::string, NodePtr> all_nodes;
for (auto node : graph->GetAllNodes()) {
EXPECT_NE(node, nullptr);
all_nodes[node->GetName()] = node;
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch1");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add1;0");
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch2");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add2;1");
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add3;2");
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch4");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;add4;3");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 1);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "prefetch3");
}

{
std::vector<std::string> event_info;
auto prefetch = all_nodes.at("prefetch5");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add5;4");
EXPECT_EQ(event_info.at(1), "RecvFrom;add3;0");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "add3");
}
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_with_batch_label_success_test) {
ut::BufferPoolGraphBuilder builder("GraphWithMultiBatch");
ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiBatch();

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("batch_label_256/prefetch1");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add1;4");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("batch_label_256/prefetch2");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add2;5");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("batch_label_256/prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add3;6");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("batch_label_256/prefetch4");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add4;7");
EXPECT_EQ(event_info.at(1), "RecvFrom;batch_label_256/add2;4");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "batch_label_256/add2");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("batch_label_256/prefetch5");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;batch_label_256/add5;4");
EXPECT_EQ(event_info.at(1), "RecvFrom;batch_label_256/add3;5");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "batch_label_256/add3");
}
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_node_has_multi_output_success_test) {
ut::BufferPoolGraphBuilder builder("GraphWithMultiOutputPrefetch");
ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiOutputPrefetch();

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, SUCCESS);

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch1");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;prefetch1_memcpy_async;0");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch2");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;prefetch2_memcpy_async;1");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 1);
EXPECT_EQ(event_info.at(0), "SendTo;prefetch3_memcpy_async;2");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch4");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;prefetch4_memcpy_async;3");
EXPECT_EQ(event_info.at(1), "RecvFrom;prefetch2_memcpy_async;0");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "prefetch2_memcpy_async");
}

{
std::vector<std::string> event_info;
auto prefetch = graph->FindNode("prefetch5");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
EXPECT_EQ(event_info.size(), 2);
EXPECT_EQ(event_info.at(0), "SendTo;add5;0");
EXPECT_EQ(event_info.at(1), "RecvFrom;prefetch3_memcpy_async;1");
auto in_ctrl_nodes = prefetch->GetInControlNodes();
EXPECT_EQ(in_ctrl_nodes.size(), 2);
EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "prefetch3_memcpy_async");
}
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_has_different_size_fail_test) {
ut::BufferPoolGraphBuilder builder("NormalGraph");
ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
const int64_t dummy_size = 256;
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
(void) AttrUtils::SetInt(prefetch->GetOpDesc(), "_buffer_pool_size", dummy_size);
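// Prefetch nodes sharing a buffer pool must report the same pool size; overriding
// prefetch3's "_buffer_pool_size" makes the pass fail.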

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, FAILED);
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_size_is_not_enough_fail_test) {
ut::BufferPoolGraphBuilder builder("NormalGraph");
ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
builder.SetPrefetchNodeInfo(prefetch, buffer_pool_id, buffer_pool_size, {buffer_pool_size + 512});
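// prefetch3's output (pool size + 512 bytes) cannot fit into the pool, so the pass fails.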

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, FAILED);
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_size_is_not_enough_for_multi_fail_test) {
ut::BufferPoolGraphBuilder builder("GraphWithMultiPrefetch");
ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiPrefetch();
const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;
auto prefetch = graph->FindNode("prefetch3");
EXPECT_NE(prefetch, nullptr);
builder.SetPrefetchNodeInfo(prefetch, buffer_pool_id, buffer_pool_size, {buffer_pool_size});
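// prefetch3's output alone fills the whole pool, so the remaining prefetch outputs
// cannot be placed and the pass fails.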

BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, FAILED);
}

TEST_F(UtestBufferPoolMemoryPass, buffer_pool_node_has_multi_input_output_fail_test) {
ut::BufferPoolGraphBuilder builder("GraphWithMultiInputOutputPrefetch");
ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiInputOutputPrefetch();
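// A prefetch node with multiple data inputs and outputs is rejected by the pass.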
BufferPoolMemoryPass buffer_pool_mem_pass;
Status ret = buffer_pool_mem_pass.Run(graph);
EXPECT_EQ(ret, FAILED);
}
} // namespace ge

+ 978
- 0
tests/ut/ge/graph/utils/buffer_pool_graph_builder.cc View File

@@ -0,0 +1,978 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "buffer_pool_graph_builder.h"
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"

namespace ge {
namespace ut {
BufferPoolGraphBuilder::BufferPoolGraphBuilder(const std::string &name) {
graph_name_ = name;
}

BufferPoolGraphBuilder::InnerGraphBuilder::InnerGraphBuilder(const std::string &name) {
graph_ = std::make_shared<ComputeGraph>(name);
EXPECT_NE(graph_, nullptr);
}

NodePtr BufferPoolGraphBuilder::InnerGraphBuilder::AddNode(const std::string &name, const std::string &type,
int in_cnt, int out_cnt,
Format format, DataType data_type,
std::vector<int64_t> shape) {
auto tensor_desc = std::make_shared<GeTensorDesc>();
EXPECT_NE(tensor_desc, nullptr);
tensor_desc->SetShape(GeShape(std::move(shape)));
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);
auto op_desc = std::make_shared<OpDesc>(name, type);
EXPECT_NE(op_desc, nullptr);
for (int i = 0; i < in_cnt; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (int i = 0; i < out_cnt; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}
return graph_->AddNode(op_desc);
}

void BufferPoolGraphBuilder::InnerGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx,
NodePtr &dst_node, int dst_idx) {
EXPECT_NE(src_node, nullptr);
EXPECT_NE(dst_node, nullptr);
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}

void BufferPoolGraphBuilder::InnerGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
EXPECT_NE(src_node, nullptr);
EXPECT_NE(dst_node, nullptr);
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}

void BufferPoolGraphBuilder::SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size,
const std::string &batch_label) {
EXPECT_NE(node, nullptr);
(void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, pool_id);
(void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, pool_size);
if (!batch_label.empty()) {
(void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
}
}

void BufferPoolGraphBuilder::SetBatchLabel(NodePtr &node, const std::string &batch_label) {
EXPECT_NE(node, nullptr);
(void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
}

void BufferPoolGraphBuilder::SetOutputMemSize(NodePtr &node, const std::vector<int64_t> &mem_size) {
EXPECT_NE(node, nullptr);
EXPECT_NE(node->GetOpDesc(), nullptr);
size_t output_size = node->GetOpDesc()->GetOutputsSize();
EXPECT_EQ(output_size, mem_size.size());
for (size_t i = 0; i < output_size; ++i) {
auto output_op_desc = node->GetOpDesc()->MutableOutputDesc(i);
ge::TensorUtils::SetSize(*output_op_desc, mem_size[i]);
}
}

void BufferPoolGraphBuilder::SetWorkSpaceMemSize(NodePtr &node, const std::vector<int64_t> &ws_bytes) {
EXPECT_NE(node, nullptr);
EXPECT_NE(node->GetOpDesc(), nullptr);
node->GetOpDesc()->SetWorkspaceBytes(ws_bytes);
}

void BufferPoolGraphBuilder::SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size,
const std::vector<int64_t> &mem_size,
const std::vector<int64_t> &ws_bytes,
const std::string &batch_label) {
SetBufferPool(node, pool_id, pool_size, batch_label);
SetOutputMemSize(node, mem_size);
SetWorkSpaceMemSize(node, ws_bytes);
}

///
/// Normal graph
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraph() {
auto builder = InnerGraphBuilder(graph_name_);
auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);

const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;

auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});

auto add1 = builder.AddNode("add1", ADD, 2, 1);
auto add2 = builder.AddNode("add2", ADD, 2, 1);
auto add3 = builder.AddNode("add3", ADD, 2, 1);
auto add4 = builder.AddNode("add4", ADD, 2, 1);
auto add5 = builder.AddNode("add5", ADD, 2, 1);
auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);

builder.AddDataEdge(w1, 0, prefetch1, 0);
builder.AddDataEdge(w2, 0, prefetch2, 0);
builder.AddDataEdge(w3, 0, prefetch3, 0);
builder.AddDataEdge(w4, 0, prefetch4, 0);
builder.AddDataEdge(w5, 0, prefetch5, 0);

builder.AddDataEdge(const1, 0, add1, 0);
builder.AddDataEdge(prefetch1, 0, add1, 1);

builder.AddDataEdge(add1, 0, add2, 0);
builder.AddDataEdge(prefetch2, 0, add2, 1);

builder.AddDataEdge(add2, 0, add3, 0);
builder.AddDataEdge(prefetch3, 0, add3, 1);

builder.AddDataEdge(add3, 0, add4, 0);
builder.AddDataEdge(prefetch4, 0, add4, 1);

builder.AddDataEdge(add4, 0, add5, 0);
builder.AddDataEdge(prefetch5, 0, add5, 1);

builder.AddDataEdge(add5, 0, net_output, 0);

auto compute_graph = builder.GetGraph();

return compute_graph;
}

///
/// Normal graph with multi buffer pool
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// (pool0) (pool1) (pool0) (pool0) (pool1)
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
/// Memory distribution:
///
/// |___w1__|__w3__|_________|
/// |_____w4_____|___________|
///
/// |___w2__|_____w5___|_____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraphWithMultiBufferPool() {
auto builder = InnerGraphBuilder(graph_name_);
auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);

const int64_t buffer_pool_id_0 = 0;
const int64_t buffer_pool_id_1 = 1;
const int64_t buffer_pool_size = 5000;

auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id_0, buffer_pool_size, {500});
auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id_1, buffer_pool_size, {500});
auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id_0, buffer_pool_size, {500});
auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id_0, buffer_pool_size, {1024});
auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id_1, buffer_pool_size, {1024});

auto add1 = builder.AddNode("add1", ADD, 2, 1);
auto add2 = builder.AddNode("add2", ADD, 2, 1);
auto add3 = builder.AddNode("add3", ADD, 2, 1);
auto add4 = builder.AddNode("add4", ADD, 2, 1);
auto add5 = builder.AddNode("add5", ADD, 2, 1);
auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);

builder.AddDataEdge(w1, 0, prefetch1, 0);
builder.AddDataEdge(w2, 0, prefetch2, 0);
builder.AddDataEdge(w3, 0, prefetch3, 0);
builder.AddDataEdge(w4, 0, prefetch4, 0);
builder.AddDataEdge(w5, 0, prefetch5, 0);

builder.AddDataEdge(const1, 0, add1, 0);
builder.AddDataEdge(prefetch1, 0, add1, 1);

builder.AddDataEdge(add1, 0, add2, 0);
builder.AddDataEdge(prefetch2, 0, add2, 1);

builder.AddDataEdge(add2, 0, add3, 0);
builder.AddDataEdge(prefetch3, 0, add3, 1);

builder.AddDataEdge(add3, 0, add4, 0);
builder.AddDataEdge(prefetch4, 0, add4, 1);

builder.AddDataEdge(add4, 0, add5, 0);
builder.AddDataEdge(prefetch5, 0, add5, 1);

builder.AddDataEdge(add5, 0, net_output, 0);

auto compute_graph = builder.GetGraph();

return compute_graph;
}

///
/// SerialGraph: The buffer pool is only large enough to hold one prefetch output at a time
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
/// Memory distribution:
///
/// |____w1_____|__|
///
/// |____w2_____|__|
///
/// |____w3_____|__|
///
/// |______w4______|
///
/// |______w5______|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildSerialGraph() {
auto builder = InnerGraphBuilder(graph_name_);
auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);

const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 2048;

auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});

auto add1 = builder.AddNode("add1", ADD, 2, 1);
auto add2 = builder.AddNode("add2", ADD, 2, 1);
auto add3 = builder.AddNode("add3", ADD, 2, 1);
auto add4 = builder.AddNode("add4", ADD, 2, 1);
auto add5 = builder.AddNode("add5", ADD, 2, 1);
auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);

builder.AddDataEdge(w1, 0, prefetch1, 0);
builder.AddDataEdge(w2, 0, prefetch2, 0);
builder.AddDataEdge(w3, 0, prefetch3, 0);
builder.AddDataEdge(w4, 0, prefetch4, 0);
builder.AddDataEdge(w5, 0, prefetch5, 0);

builder.AddDataEdge(const1, 0, add1, 0);
builder.AddDataEdge(prefetch1, 0, add1, 1);

builder.AddDataEdge(add1, 0, add2, 0);
builder.AddDataEdge(prefetch2, 0, add2, 1);

builder.AddDataEdge(add2, 0, add3, 0);
builder.AddDataEdge(prefetch3, 0, add3, 1);

builder.AddDataEdge(add3, 0, add4, 0);
builder.AddDataEdge(prefetch4, 0, add4, 1);

builder.AddDataEdge(add4, 0, add5, 0);
builder.AddDataEdge(prefetch5, 0, add5, 1);

builder.AddDataEdge(add5, 0, net_output, 0);

auto compute_graph = builder.GetGraph();

return compute_graph;
}

///
/// GraphWithMultiPrefetch: Calc nodes consume multiple prefetch nodes
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1
/// \ / \ / \ /
/// \ / \ / \ /
/// \ / \ / \ /
/// add1 ------ c ------- add2 ----- c ----- add3
/// | | |
/// | | |
/// --------------- net_output ------------
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiPrefetch() {
auto builder = InnerGraphBuilder(graph_name_);
auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);

const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;

auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});

auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
auto add1 = builder.AddNode("add1", ADD, 2, 1);
auto add2 = builder.AddNode("add2", ADD, 2, 1);
auto add3 = builder.AddNode("add3", ADD, 2, 1);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 3, 0);

builder.AddDataEdge(w1, 0, prefetch1, 0);
builder.AddDataEdge(w2, 0, prefetch2, 0);
builder.AddDataEdge(w3, 0, prefetch3, 0);
builder.AddDataEdge(w4, 0, prefetch4, 0);
builder.AddDataEdge(w5, 0, prefetch5, 0);

builder.AddDataEdge(prefetch1, 0, add1, 0);
builder.AddDataEdge(prefetch2, 0, add1, 1);

builder.AddDataEdge(prefetch3, 0, add2, 0);
builder.AddDataEdge(prefetch4, 0, add2, 1);

builder.AddDataEdge(const1, 0, add3, 0);
builder.AddDataEdge(prefetch5, 0, add3, 1);

builder.AddDataEdge(add1, 0, net_output, 0);
builder.AddDataEdge(add2, 0, net_output, 1);
builder.AddDataEdge(add3, 0, net_output, 2);

builder.AddControlEdge(add1, add2);
builder.AddControlEdge(add2, add3);

auto compute_graph = builder.GetGraph();

return compute_graph;
}

///
/// GraphWithSubgraph: Calc nodes located in different subgraphs
///
///
/// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output
///
///
/// Subgraph1: Subgraph2:
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out
///
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithSubgraph() {
auto builder = InnerGraphBuilder(graph_name_);

const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;

// Subgraph1
auto subgraph_builder1 = InnerGraphBuilder("Subgraph1");
auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1);
auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1);
auto w3 = subgraph_builder1.AddNode("w3", VARIABLE, 0, 1);

auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
auto prefetch3 = subgraph_builder1.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0);
auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1);

auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1);
auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1);
auto add3 = subgraph_builder1.AddNode("add3", ADD, 2, 1);

subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0);
subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0);
subgraph_builder1.AddDataEdge(w3, 0, prefetch3, 0);
subgraph_builder1.AddDataEdge(const1, 0, add1, 0);
subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1);
subgraph_builder1.AddDataEdge(add1, 0, add2, 0);
subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1);
subgraph_builder1.AddDataEdge(add2, 0, add3, 0);
subgraph_builder1.AddDataEdge(prefetch3, 0, add3, 1);
subgraph_builder1.AddDataEdge(add3, 0, subgraph1_out, 0);
auto subgraph1 = subgraph_builder1.GetGraph();
for (auto &node : subgraph1->GetDirectNode()) {
node->SetOwnerComputeGraph(subgraph1);
}

// Subgraph2
auto subgraph_builder2 = InnerGraphBuilder("Subgraph2");
auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1);
auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1);

auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});

auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1);
auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1);
auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1);
auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1);

subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0);
subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0);
subgraph_builder2.AddDataEdge(data1, 0, add4, 0);
subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1);
subgraph_builder2.AddDataEdge(add4, 0, add5, 0);
subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1);
subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0);

auto subgraph2 = subgraph_builder2.GetGraph();
for (auto &node : subgraph2->GetDirectNode()) {
node->SetOwnerComputeGraph(subgraph2);
}

// root graph
auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1);
auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
builder.AddDataEdge(call_node1, 0, call_node2, 0);
builder.AddDataEdge(call_node2, 0, net_output, 0);
auto compute_graph = builder.GetGraph();
call_node1->SetOwnerComputeGraph(compute_graph);
call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName());
call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName());
call_node2->SetOwnerComputeGraph(compute_graph);
call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName());
call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName());

subgraph1->SetParentNode(call_node1);
subgraph1->SetParentGraph(compute_graph);
subgraph2->SetParentNode(call_node2);
subgraph2->SetParentGraph(compute_graph);
compute_graph->AddSubGraph(subgraph1);
compute_graph->AddSubGraph(subgraph2);

return compute_graph;
}

///
/// SubgraphWithInnerDependency: Calc nodes in different subgraphs with an inner dependency
///
///
/// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output
///
///
/// Subgraph1: Subgraph2:
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out
///
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildSubgraphWithInnerDependency() {
auto builder = InnerGraphBuilder(graph_name_);

const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;

// Subgraph1
auto subgraph_builder1 = InnerGraphBuilder("Subgraph1");
auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1);
auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1);

auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0);
auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1);

auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1);
auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1);

subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0);
subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0);
subgraph_builder1.AddDataEdge(const1, 0, add1, 0);
subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1);
subgraph_builder1.AddDataEdge(add1, 0, add2, 0);
subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1);
subgraph_builder1.AddDataEdge(add2, 0, subgraph1_out, 0);
auto subgraph1 = subgraph_builder1.GetGraph();
for (auto &node : subgraph1->GetDirectNode()) {
node->SetOwnerComputeGraph(subgraph1);
}

// Subgraph2
auto subgraph_builder2 = InnerGraphBuilder("Subgraph2");
auto w3 = subgraph_builder2.AddNode("w3", VARIABLE, 0, 1);
auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1);
auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1);

auto prefetch3 = subgraph_builder2.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});

auto add3 = subgraph_builder2.AddNode("add3", ADD, 2, 1);
auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1);
auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1);
auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1);
auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1);

subgraph_builder2.AddDataEdge(w3, 0, prefetch3, 0);
subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0);
subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0);
subgraph_builder2.AddDataEdge(data1, 0, add3, 0);
subgraph_builder2.AddDataEdge(prefetch3, 0, add3, 1);
subgraph_builder2.AddDataEdge(add3, 0, add4, 0);
subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1);
subgraph_builder2.AddDataEdge(add4, 0, add5, 0);
subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1);
subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0);

auto subgraph2 = subgraph_builder2.GetGraph();
for (auto &node : subgraph2->GetDirectNode()) {
node->SetOwnerComputeGraph(subgraph2);
}

// root graph
auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1);
auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
builder.AddDataEdge(call_node1, 0, call_node2, 0);
builder.AddDataEdge(call_node2, 0, net_output, 0);
auto compute_graph = builder.GetGraph();
call_node1->SetOwnerComputeGraph(compute_graph);
call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName());
call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName());
call_node2->SetOwnerComputeGraph(compute_graph);
call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName());
call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName());

subgraph1->SetParentNode(call_node1);
subgraph1->SetParentGraph(compute_graph);
subgraph2->SetParentNode(call_node2);
subgraph2->SetParentGraph(compute_graph);
compute_graph->AddSubGraph(subgraph1);
compute_graph->AddSubGraph(subgraph2);

return compute_graph;
}

///
/// BuildGraphWithMultiBatch: Two branches with different batch labels
///
///
/// batch_label_128
///
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
/// / / / / / / \
/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \
/// const1 switch_false / / / / / \
/// \ / / / / / / \
/// switch1 w1 w2 w3 w4 w5 merge1 -- net_output
/// / \ \ \ \ \ \ /
/// const2 switch_true \ \ \ \ \ /
/// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /
/// \ \ \ \ \ \ /
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
///
/// batch_label_256
///
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiBatch() {
auto builder = InnerGraphBuilder(graph_name_);
auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);

auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
auto const2 = builder.AddNode("const2", CONSTANTOP, 0, 1);
auto switch1 = builder.AddNode("switch1", SWITCH, 2, 2);
auto switch_false = builder.AddNode("switch_false", IDENTITY, 1, 1);
auto switch_true = builder.AddNode("switch_true", IDENTITY, 1, 1);
auto merge1 = builder.AddNode("merge1", MERGE, 2, 2);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);

builder.AddDataEdge(const1, 0, switch1, 0);
builder.AddDataEdge(const2, 0, switch1, 1);
builder.AddDataEdge(switch1, 0, switch_false, 0);
builder.AddDataEdge(switch1, 1, switch_true, 0);
builder.AddDataEdge(merge1, 0, net_output, 0);
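// switch_false gates the batch_128 branch and switch_true gates the batch_256 branch
// (via control edges to each branch's const1); both branches rejoin at merge1.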

std::string batch_label_128 = "batch_128";
std::string batch_label_256 = "batch_256";

const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;

{
auto prefetch1 = builder.AddNode("batch_label_128/prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
auto prefetch2 = builder.AddNode("batch_label_128/prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
auto prefetch3 = builder.AddNode("batch_label_128/prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
auto prefetch4 = builder.AddNode("batch_label_128/prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128);
auto prefetch5 = builder.AddNode("batch_label_128/prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128);

auto add1 = builder.AddNode("batch_label_128/add1", ADD, 2, 1);
SetBatchLabel(add1, batch_label_128);
auto add2 = builder.AddNode("batch_label_128/add2", ADD, 2, 1);
SetBatchLabel(add2, batch_label_128);
auto add3 = builder.AddNode("batch_label_128/add3", ADD, 2, 1);
SetBatchLabel(add3, batch_label_128);
auto add4 = builder.AddNode("batch_label_128/add4", ADD, 2, 1);
SetBatchLabel(add4, batch_label_128);
auto add5 = builder.AddNode("batch_label_128/add5", ADD, 2, 1);
SetBatchLabel(add5, batch_label_128);
auto const1 = builder.AddNode("batch_label_128/const1", CONSTANTOP, 0, 1);
SetBatchLabel(const1, batch_label_128);

builder.AddDataEdge(w1, 0, prefetch1, 0);
builder.AddDataEdge(w2, 0, prefetch2, 0);
builder.AddDataEdge(w3, 0, prefetch3, 0);
builder.AddDataEdge(w4, 0, prefetch4, 0);
builder.AddDataEdge(w5, 0, prefetch5, 0);

builder.AddDataEdge(const1, 0, add1, 0);
builder.AddDataEdge(prefetch1, 0, add1, 1);

builder.AddDataEdge(add1, 0, add2, 0);
builder.AddDataEdge(prefetch2, 0, add2, 1);

builder.AddDataEdge(add2, 0, add3, 0);
builder.AddDataEdge(prefetch3, 0, add3, 1);

builder.AddDataEdge(add3, 0, add4, 0);
builder.AddDataEdge(prefetch4, 0, add4, 1);

builder.AddDataEdge(add4, 0, add5, 0);
builder.AddDataEdge(prefetch5, 0, add5, 1);

builder.AddDataEdge(add5, 0, merge1, 0);
builder.AddControlEdge(switch_false, const1);
}

{
auto prefetch1 = builder.AddNode("batch_label_256/prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
auto prefetch2 = builder.AddNode("batch_label_256/prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
auto prefetch3 = builder.AddNode("batch_label_256/prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
auto prefetch4 = builder.AddNode("batch_label_256/prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256);
auto prefetch5 = builder.AddNode("batch_label_256/prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256);

auto add1 = builder.AddNode("batch_label_256/add1", ADD, 2, 1);
SetBatchLabel(add1, batch_label_256);
auto add2 = builder.AddNode("batch_label_256/add2", ADD, 2, 1);
SetBatchLabel(add2, batch_label_256);
auto add3 = builder.AddNode("batch_label_256/add3", ADD, 2, 1);
SetBatchLabel(add3, batch_label_256);
auto add4 = builder.AddNode("batch_label_256/add4", ADD, 2, 1);
SetBatchLabel(add4, batch_label_256);
auto add5 = builder.AddNode("batch_label_256/add5", ADD, 2, 1);
SetBatchLabel(add5, batch_label_256);
auto const1 = builder.AddNode("batch_label_256/const1", CONSTANTOP, 0, 1);
SetBatchLabel(const1, batch_label_256);

builder.AddDataEdge(w1, 0, prefetch1, 0);
builder.AddDataEdge(w2, 0, prefetch2, 0);
builder.AddDataEdge(w3, 0, prefetch3, 0);
builder.AddDataEdge(w4, 0, prefetch4, 0);
builder.AddDataEdge(w5, 0, prefetch5, 0);

builder.AddDataEdge(const1, 0, add1, 0);
builder.AddDataEdge(prefetch1, 0, add1, 1);

builder.AddDataEdge(add1, 0, add2, 0);
builder.AddDataEdge(prefetch2, 0, add2, 1);

builder.AddDataEdge(add2, 0, add3, 0);
builder.AddDataEdge(prefetch3, 0, add3, 1);

builder.AddDataEdge(add3, 0, add4, 0);
builder.AddDataEdge(prefetch4, 0, add4, 1);

builder.AddDataEdge(add4, 0, add5, 0);
builder.AddDataEdge(prefetch5, 0, add5, 1);

builder.AddDataEdge(add5, 0, merge1, 1);

builder.AddControlEdge(switch_true, const1);
}

auto compute_graph = builder.GetGraph();

return compute_graph;
}

///
/// GraphWithMultiOutputPrefetch: Prefetch has more than one output
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// / \ / \ / \ / \ /
/// / \ / \ / \ / \ /
/// const1 ----- add1 add2 add3 add4 add5
/// | \ | / |
/// | \ | / |
/// | \ | / |
/// | \ | / |
/// -------------- net_output ---------------
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiOutputPrefetch() {
auto builder = InnerGraphBuilder(graph_name_);
auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);

const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;

auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});

auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
auto add1 = builder.AddNode("add1", ADD, 2, 1);
auto add2 = builder.AddNode("add2", ADD, 2, 1);
auto add3 = builder.AddNode("add3", ADD, 2, 1);
auto add4 = builder.AddNode("add4", ADD, 2, 1);
auto add5 = builder.AddNode("add5", ADD, 2, 1);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0);

builder.AddDataEdge(w1, 0, prefetch1, 0);
builder.AddDataEdge(w2, 0, prefetch2, 0);
builder.AddDataEdge(w3, 0, prefetch3, 0);
builder.AddDataEdge(w4, 0, prefetch4, 0);
builder.AddDataEdge(w5, 0, prefetch5, 0);

builder.AddDataEdge(const1, 0, add1, 0);
builder.AddDataEdge(prefetch1, 0, add1, 1);

builder.AddDataEdge(prefetch1, 0, add2, 0);
builder.AddDataEdge(prefetch2, 0, add2, 1);

builder.AddDataEdge(prefetch2, 0, add3, 0);
builder.AddDataEdge(prefetch3, 0, add3, 1);

builder.AddDataEdge(prefetch3, 0, add4, 0);
builder.AddDataEdge(prefetch4, 0, add4, 1);

builder.AddDataEdge(prefetch4, 0, add5, 0);
builder.AddDataEdge(prefetch5, 0, add5, 1);

builder.AddDataEdge(add1, 0, net_output, 0);
builder.AddDataEdge(add2, 0, net_output, 1);
builder.AddDataEdge(add3, 0, net_output, 2);
builder.AddDataEdge(add4, 0, net_output, 3);
builder.AddDataEdge(add5, 0, net_output, 4);

auto compute_graph = builder.GetGraph();

return compute_graph;
}

///
/// GraphWithMultiInputOutputPrefetch: Prefetch has more than one input and output
///
/// w1 w2 w3 w4 w5
/// \ / \ / \ / \ / \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// / \ / \ / \ / \ /
/// / \ / \ / \ / \ /
/// const1 ----- add1 add2 add3 add4 add5
/// | \ | / |
/// | \ | / |
/// | \ | / |
/// | \ | / |
/// -------------- net_output ---------------
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiInputOutputPrefetch() {
auto builder = InnerGraphBuilder(graph_name_);
auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);

const int64_t buffer_pool_id = 0;
const int64_t buffer_pool_size = 5600;
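// prefetch1..prefetch4 each take two weight inputs and produce two outputs shared by
// adjacent add nodes; prefetch5 keeps a single input and output.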

auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 2, 2);
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500, 500});
auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 2, 2);
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500, 500});
auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 2, 2);
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500, 1024});
auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 2, 2);
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024, 1024});
auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});

auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
auto add1 = builder.AddNode("add1", ADD, 2, 1);
auto add2 = builder.AddNode("add2", ADD, 2, 1);
auto add3 = builder.AddNode("add3", ADD, 2, 1);
auto add4 = builder.AddNode("add4", ADD, 2, 1);
auto add5 = builder.AddNode("add5", ADD, 2, 1);
auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0);

builder.AddDataEdge(w1, 0, prefetch1, 0);
builder.AddDataEdge(w2, 0, prefetch1, 1);
builder.AddDataEdge(w2, 0, prefetch2, 0);
builder.AddDataEdge(w3, 0, prefetch2, 1);
builder.AddDataEdge(w3, 0, prefetch3, 0);
builder.AddDataEdge(w4, 0, prefetch3, 1);
builder.AddDataEdge(w4, 0, prefetch4, 0);
builder.AddDataEdge(w5, 0, prefetch4, 1);
builder.AddDataEdge(w5, 0, prefetch5, 0);

builder.AddDataEdge(const1, 0, add1, 0);
builder.AddDataEdge(prefetch1, 0, add1, 1);

builder.AddDataEdge(prefetch1, 1, add2, 0);
builder.AddDataEdge(prefetch2, 0, add2, 1);

builder.AddDataEdge(prefetch2, 1, add3, 0);
builder.AddDataEdge(prefetch3, 0, add3, 1);

builder.AddDataEdge(prefetch3, 1, add4, 0);
builder.AddDataEdge(prefetch4, 0, add4, 1);

builder.AddDataEdge(prefetch4, 1, add5, 0);
builder.AddDataEdge(prefetch5, 0, add5, 1);

builder.AddDataEdge(add1, 0, net_output, 0);
builder.AddDataEdge(add2, 0, net_output, 1);
builder.AddDataEdge(add3, 0, net_output, 2);
builder.AddDataEdge(add4, 0, net_output, 3);
builder.AddDataEdge(add5, 0, net_output, 4);

auto compute_graph = builder.GetGraph();

return compute_graph;
}
} // namespace ut
} // namespace ge

+ 279
- 0
tests/ut/ge/graph/utils/buffer_pool_graph_builder.h View File

@@ -0,0 +1,279 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_
#define GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_

#include <string>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/node.h"

namespace ge {
namespace ut {
class BufferPoolGraphBuilder {
public:
explicit BufferPoolGraphBuilder(const std::string &name = "BufferPoolGraph");
~BufferPoolGraphBuilder() {}
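///
/// Typical usage in the buffer pool unittests:
///   ut::BufferPoolGraphBuilder builder("NormalGraph");
///   ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
/// Each Build* method returns a topologically sorted ComputeGraph whose prefetch nodes
/// already carry the buffer pool id/size and output memory size attributes.
///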
class InnerGraphBuilder {
public:
explicit InnerGraphBuilder(const std::string &name);
~InnerGraphBuilder() {}
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
std::vector<int64_t> shape = {1, 1, 224, 224});

void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx);

void AddControlEdge(NodePtr &src_node, NodePtr &dst_node);

ComputeGraphPtr GetGraph() {
graph_->TopologicalSorting();
return graph_;
}
private:
ComputeGraphPtr graph_;
};

///
/// Normal graph
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BuildNormalGraph();

///
/// Normal graph with multi buffer pool
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// (pool0) (pool1) (pool0) (pool0) (pool1)
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
/// Memory distribution:
///
/// |___w1__|__w3__|_________|
/// |_____w4_____|___________|
///
/// |___w2__|_____w5___|_____|
///
ComputeGraphPtr BuildNormalGraphWithMultiBufferPool();

///
/// SerialGraph: The buffer pool is only large enough to hold one prefetch output at a time
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
/// Memory distribution:
///
/// |____w1_____|__|
///
/// |____w2_____|__|
///
/// |____w3_____|__|
///
/// |______w4______|
///
/// |______w5______|
///
ComputeGraphPtr BuildSerialGraph();

///
/// GraphWithMultiPrefetch: Calc nodes consume multiple prefetch nodes
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1
/// \ / \ / \ /
/// \ / \ / \ /
/// \ / \ / \ /
/// add1 ------ c ------- add2 ----- c ----- add3
/// | | |
/// | | |
/// --------------- net_output ------------
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BuildGraphWithMultiPrefetch();

///
/// GraphWithSubgraph: Calc nodes located in different subgraphs
///
///
/// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output
///
///
/// Subgraph1: Subgraph2:
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out
///
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BuildGraphWithSubgraph();

///
/// SubgraphWithInnerDependency: Calc nodes in different subgraphs with an inner dependency
///
///
/// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output
///
///
/// Subgraph1: Subgraph2:
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out
///
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BuildSubgraphWithInnerDependency();

///
/// BuildGraphWithMultiBatch: Two branches with different batch labels
///
///
/// batch_label_128
///
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
/// / / / / / / \
/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \
/// const1 switch_false / / / / / \
/// \ / / / / / / \
/// switch1 w1 w2 w3 w4 w5 merge1 -- net_output
/// / \ \ \ \ \ \ /
/// const2 switch_true \ \ \ \ \ /
/// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /
/// \ \ \ \ \ \ /
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
///
/// batch_label_256
///
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BuildGraphWithMultiBatch();

///
/// GraphWithMultiOutputPrefetch: Prefetch has more than one output
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// / \ / \ / \ / \ /
/// / \ / \ / \ / \ /
/// const1 ----- add1 add2 add3 add4 add5
/// | \ | / |
/// | \ | / |
/// | \ | / |
/// | \ | / |
/// -------------- net_output ---------------
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BuildGraphWithMultiOutputPrefetch();

///
/// GraphWithMultiInputOutputPrefetch: Prefetch has more than one input and output
///
/// w1 w2 w3 w4 w5
/// \ / \ / \ / \ / \
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// / \ / \ / \ / \ /
/// / \ / \ / \ / \ /
/// const1 ----- add1 add2 add3 add4 add5
/// | \ | / |
/// | \ | / |
/// | \ | / |
/// | \ | / |
/// -------------- net_output ---------------
///
/// Memory distribution:
///
/// |___w1__|__w2__|__w3__|__|
///
/// |_____w4_____|_____w5____|
///
ComputeGraphPtr BuildGraphWithMultiInputOutputPrefetch();

void SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size, const std::string &batch_label = "");

void SetBatchLabel(NodePtr &node, const std::string &batch_label = "");

void SetOutputMemSize(NodePtr &node, const std::vector<int64_t> &mem_size = {1024});

void SetWorkSpaceMemSize(NodePtr &node, const std::vector<int64_t> &ws_bytes = {1024});

void SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size,
const std::vector<int64_t> &mem_size = {1024},
const std::vector<int64_t> &ws_bytes = {1024},
const std::string &batch_label = "");

private:
std::string graph_name_;
};
} // namespace ut
} // namespace ge

#endif // GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_
