From: @selfws Reviewed-by: @xchu42 Signed-off-by:tags/v1.3.0
@@ -329,6 +329,7 @@ set(TRAIN_SRC_LIST | |||
"graph/passes/memcpy_addr_async_pass.cc" | |||
"graph/passes/parallel_group_pass.cc" | |||
"graph/passes/set_input_output_offset_pass.cc" | |||
"graph/passes/buffer_pool_memory_pass.cc" | |||
"graph/preprocess/graph_preprocess.cc" | |||
"graph/preprocess/insert_op/ge_aipp_op.cc" | |||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
@@ -407,6 +408,7 @@ set(TRAIN_SRC_LIST | |||
"graph/build/memory/hybrid_mem_assigner.cc" | |||
"graph/build/memory/max_block_mem_assigner.cc" | |||
"graph/build/memory/var_mem_assign_util.cc" | |||
"graph/build/memory/buffer_pool_mem_assigner.cc" | |||
) | |||
set(INFER_SRC_LIST | |||
@@ -617,6 +619,7 @@ set(INFER_SRC_LIST | |||
"graph/passes/memcpy_addr_async_pass.cc" | |||
"graph/passes/set_input_output_offset_pass.cc" | |||
"graph/passes/parallel_group_pass.cc" | |||
"graph/passes/buffer_pool_memory_pass.cc" | |||
"graph/manager/model_manager/event_manager.cc" | |||
"graph/manager/util/rt_context_util.cc" | |||
"graph/manager/util/variable_accelerate_ctrl.cc" | |||
@@ -680,6 +683,7 @@ set(INFER_SRC_LIST | |||
"graph/build/memory/hybrid_mem_assigner.cc" | |||
"graph/build/memory/max_block_mem_assigner.cc" | |||
"graph/build/memory/var_mem_assign_util.cc" | |||
"graph/build/memory/buffer_pool_mem_assigner.cc" | |||
) | |||
if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | |||
@@ -222,6 +222,7 @@ OMG_HOST_SRC_FILES := \ | |||
graph/passes/hccl_group_pass.cc \ | |||
graph/passes/memcpy_addr_async_pass.cc \ | |||
graph/passes/set_input_output_offset_pass.cc \ | |||
graph/passes/buffer_pool_memory_pass.cc \ | |||
OMG_DEVICE_SRC_FILES := $(OMG_HOST_SRC_FILES) | |||
@@ -246,6 +246,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
graph/passes/end_of_sequence_add_control_pass.cc \ | |||
graph/passes/memcpy_addr_async_pass.cc \ | |||
graph/passes/set_input_output_offset_pass.cc \ | |||
graph/passes/buffer_pool_memory_pass.cc \ | |||
graph/preprocess/graph_preprocess.cc \ | |||
graph/preprocess/insert_op/ge_aipp_op.cc \ | |||
graph/preprocess/insert_op/util_insert_aipp_op.cc \ | |||
@@ -1655,6 +1655,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector | |||
bool is_atomic = false; | |||
// If GetBool fail, is_atomic is false. | |||
(void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic); | |||
bool is_buffer_pool_mem_supported = (op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID)) && | |||
(op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) && (!root_unknown_shape_flag_); | |||
// Allocate memory for the current node and release node memory of the same size in the workspace | |||
GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1", | |||
for (auto iter = stream_workspace_blocks_.begin(); iter != stream_workspace_blocks_.end(); | |||
@@ -1694,7 +1696,7 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector | |||
GE_IF_BOOL_EXEC(!no_need_assign_memory, | |||
no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);); | |||
} | |||
no_need_assign_memory = (no_need_assign_memory || IsKnownSubgraphData(node)); | |||
no_need_assign_memory = (no_need_assign_memory || IsKnownSubgraphData(node) || is_buffer_pool_mem_supported); | |||
if (no_need_assign_memory) { | |||
zero_memory_list_.emplace_back(node, kOutput, i, false); | |||
continue; | |||
@@ -1740,6 +1742,13 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||
const char *op_no_reuse_mem = std::getenv(OP_NO_REUSE_MEM); | |||
GE_IF_BOOL_EXEC(op_no_reuse_mem != nullptr, op_no_reuse_mem_str = string(op_no_reuse_mem); | |||
CheckAndGetOpReuseEnv(op_no_reuse_mem_str, op_no_reuse_mem_vec_, op_reuse_env_valid_);); | |||
auto root_graph = GraphUtils::FindRootGraph(compute_graph_); | |||
if (root_graph == nullptr) { | |||
GELOGE(INTERNAL_ERROR, "[Check][RootGraph]Root graph is nullptr, graph:%s.", compute_graph_->GetName().c_str()); | |||
REPORT_INNER_ERROR("E19999", "Root graph is nullptr, graph:%s.", compute_graph_->GetName().c_str()); | |||
return; | |||
} | |||
root_unknown_shape_flag_ = root_graph->GetGraphUnknownFlag(); | |||
for (NodePtr &n : compute_graph_->GetAllNodes()) { | |||
auto node_op_desc = n->GetOpDesc(); | |||
@@ -494,6 +494,8 @@ class BlockMemAssigner : public MemAssigner { | |||
/// @ [stream2][nodeid] | |||
/// | |||
DependStreamLife total_node_depend_stream_life_; | |||
bool root_unknown_shape_flag_ = false; | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_ |
@@ -0,0 +1,234 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/build/memory/buffer_pool_mem_assigner.h" | |||
#include "graph/common/omg_util.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "framework/common/util.h" | |||
#include "graph/compute_graph.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "common/math/math_util.h" | |||
#include "common/util/error_manager/error_manager.h" | |||
namespace ge { | |||
namespace {
// ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET must carry exactly this many entries: {size, offset}.
constexpr size_t kBufferPoolNodeMemInfoLength = 2U;
// Slot of the recorded output size inside that attribute list.
constexpr uint32_t kBufferPoolNodeOutputSizeIndex = 0U;
// Slot of the logical offset (added to the pool base offset) inside that attribute list.
constexpr uint32_t kBufferPoolNodeOutputOffsetIndex = 1U;
}  // namespace
/// @brief Run the buffer pool memory assignment for compute_graph_.
/// Initializes the assigner state via InitAssigner, then assigns output offsets via AssignOutput.
/// @return PARAM_INVALID if the graph is null, FAILED if either stage fails, otherwise SUCCESS.
Status BufferPoolMemAssigner::Assign() {
  if (compute_graph_ == nullptr) {
    GELOGE(PARAM_INVALID, "[Check][Graph]Graph is nullptr");
    REPORT_INNER_ERROR("E19999", "Input graph is nullptr");
    return PARAM_INVALID;
  }

  // Stage 1: collect buffer pool nodes and lay out the pools.
  if (InitAssigner(compute_graph_) != SUCCESS) {
    GELOGE(FAILED, "[Init][Assigner]Graph:%s.", compute_graph_->GetName().c_str());
    return FAILED;
  }

  // Stage 2: write the final output offsets onto the nodes.
  if (AssignOutput() != SUCCESS) {
    GELOGE(FAILED, "[Assign][Output]Graph:%s.", compute_graph_->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
/// @brief Look up the memory type of one output of a node.
/// @param [in] node node whose op desc is inspected
/// @param [in] idx output index
/// @param [out] memory_type RT_MEMORY_HBM when ATTR_NAME_OUTPUT_MEM_TYPE_LIST is absent,
///              otherwise the idx-th entry of that list
/// @return PARAM_INVALID when the list length differs from the output count or idx is out of range
Status BufferPoolMemAssigner::GetOutputMemoryType(const NodePtr &node, size_t idx, int64_t &memory_type) {
  GE_CHECK_NOTNULL(node->GetOpDesc());
  const auto op_desc = node->GetOpDesc();
  memory_type = RT_MEMORY_HBM;

  std::vector<int64_t> mem_types;
  if (!ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, mem_types)) {
    // No explicit memory type list: keep the HBM default.
    return SUCCESS;
  }

  const bool size_mismatch = (mem_types.size() != op_desc->GetOutputsSize());
  const bool index_overflow = (idx >= mem_types.size());
  if (size_mismatch || index_overflow) {
    GELOGE(PARAM_INVALID, "[Check][OutputParam]Output param invalid, output size:%zu, mem type size:%zu, index:%zu.",
           op_desc->GetOutputsSize(), mem_types.size(), idx);
    REPORT_INNER_ERROR("E19999", "Output param invalid, output size:%zu, mem type size:%zu, index:%zu.",
                       op_desc->GetOutputsSize(), mem_types.size(), idx);
    return PARAM_INVALID;
  }

  memory_type = mem_types[idx];
  return SUCCESS;
}
/// @brief Collect buffer pool nodes of the graph and compute the base offset of every pool.
/// A node takes part when both ATTR_NAME_BUFFER_POOL_ID and ATTR_NAME_BUFFER_POOL_SIZE can be read from it.
/// Nodes are grouped by batch label (empty when the attribute is missing) and pool id; the first size
/// seen for a pool id within a batch label is the one that is kept.
/// Every batch label is laid out starting from mem_offset_base_, and mem_offset_ finally becomes
/// mem_offset_base_ plus the largest footprint among all batch labels.
/// @param [in] graph graph whose nodes (GetAllNodes) are scanned
/// @return SUCCESS, the error of InitMemOffsetBase, or the error raised by an overflow check
Status BufferPoolMemAssigner::InitAssigner(const ComputeGraphPtr &graph) {
  // Pass 1: group the buffer pool nodes and remember the size of each pool.
  for (const NodePtr &node : graph->GetAllNodes()) {
    int64_t pool_id = 0;
    int64_t pool_size = 0;
    const bool is_pool_node = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, pool_id) &&
                              AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, pool_size);
    if (!is_pool_node) {
      continue;
    }

    std::string batch_label;
    (void) AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
    buffer_pool_nodes_[batch_label][pool_id].emplace_back(node);
    // emplace leaves an already recorded size untouched, so the first node of a pool decides its size.
    (void) buffer_pool_size_[batch_label].emplace(pool_id, pool_size);

    const Status init_ret = InitMemOffsetBase(node);
    if (init_ret != SUCCESS) {
      GELOGE(init_ret, "[Init][MemOffsetBase]Batch label:%s.", batch_label.c_str());
      REPORT_INNER_ERROR("E19999", "Failed to init offset base, batch label:%s.", batch_label.c_str());
      return init_ret;
    }
  }

  // Pass 2: place the pools of each batch label one after another, all batches sharing the same base.
  int64_t max_size = 0;
  for (const auto &batch_entry : buffer_pool_size_) {
    const std::string &batch_label = batch_entry.first;
    int64_t batch_offset = mem_offset_base_;
    for (const auto &pool_entry : batch_entry.second) {
      const int64_t pool_id = pool_entry.first;
      int64_t buffer_pool_size = pool_entry.second;
      buffer_pool_offset_base_[batch_label][pool_id] = batch_offset;
      FMK_INT64_ADDCHECK(buffer_pool_size, kBufferPoolMemAlignSize);
      AlignMemSize(buffer_pool_size, kBufferPoolMemAlignSize);
      // Each pool occupies its aligned size plus one additional alignment unit.
      FMK_INT64_ADDCHECK(batch_offset, (buffer_pool_size + kBufferPoolMemAlignSize));
      batch_offset += (buffer_pool_size + kBufferPoolMemAlignSize);
    }
    const int64_t batch_mem_size = batch_offset - mem_offset_base_;
    GELOGI("[Init][Assigner]Get batch mem size, batch label:%s, mem size:%ld.", batch_label.c_str(), batch_mem_size);
    if (batch_mem_size > max_size) {
      max_size = batch_mem_size;
    }
  }

  FMK_INT64_ADDCHECK(mem_offset_base_, max_size);
  mem_offset_ = static_cast<size_t>(mem_offset_base_ + max_size);
  GELOGI("[Init][Assigner]Init buffer pool mem assigner successfully, "
         "mem type:%ld, mem offset base:%ld, mem offset:%zu.", mem_type_, mem_offset_base_, mem_offset_);
  return SUCCESS;
}
/// @brief Fix the memory type and the base offset used by all buffer pools.
/// The first call looks up the current offset of the node's output memory type in mem_type_to_offset_,
/// aligns it and stores it in mem_offset_base_. Later calls only verify that the node uses the same
/// memory type as the first one.
/// @param [in] node buffer pool node whose output kBufferPoolNodeOutIndex is inspected
/// @return SUCCESS, the error of GetOutputMemoryType, or PARAM_INVALID for a mismatching or unknown
///         memory type
Status BufferPoolMemAssigner::InitMemOffsetBase(const NodePtr &node) {
  int64_t mem_type = 0;
  const Status ret = GetOutputMemoryType(node, static_cast<size_t>(kBufferPoolNodeOutIndex), mem_type);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Get][MemType]Node:%s, index:%u.", node->GetName().c_str(), kBufferPoolNodeOutIndex);
    REPORT_INNER_ERROR("E19999", "Failed to get output memory type, node:%s, index:%u.",
                       node->GetName().c_str(), kBufferPoolNodeOutIndex);
    return ret;
  }

  if (init_offset_base_) {
    // Base already fixed by an earlier node: only the memory type has to agree.
    if (mem_type_ == mem_type) {
      return SUCCESS;
    }
    GELOGE(PARAM_INVALID, "[Check][MemType]The memory type of all buffer pool nodes must be the same, node:%s, "
           "required:%ld, actually: %ld", node->GetName().c_str(), mem_type_, mem_type);
    REPORT_INNER_ERROR("E19999", "The memory type of all buffer pool nodes must be the same, node:%s, "
                       "required:%ld, actually: %ld", node->GetName().c_str(), mem_type_, mem_type);
    return PARAM_INVALID;
  }

  const auto offset_iter = mem_type_to_offset_.find(mem_type);
  if (offset_iter == mem_type_to_offset_.end()) {
    GELOGE(PARAM_INVALID, "[Check][MemType]Memory type is not supported, node:%s, mem type:%ld.",
           node->GetName().c_str(), mem_type);
    REPORT_INNER_ERROR("E19999", "Memory type is not supported, node:%s, mem type:%ld.",
                       node->GetName().c_str(), mem_type);
    return PARAM_INVALID;
  }

  mem_offset_base_ = static_cast<int64_t>(offset_iter->second);
  FMK_INT64_ADDCHECK(mem_offset_base_, (kBufferPoolMemAlignSize + kBufferPoolMemAlignSize));
  AlignMemSize(mem_offset_base_, kBufferPoolMemAlignSize);
  // The HCOM nodes may access the previous 512 bytes.
  mem_offset_base_ += kBufferPoolMemAlignSize;
  mem_type_ = mem_type;
  init_offset_base_ = true;
  GELOGI("[Init][MemOffsetBase]Init offset base:%ld, memory type:%ld", mem_offset_base_, mem_type);
  return SUCCESS;
}
/// @brief Assign output offsets for every buffer pool of every batch label.
/// For each (batch label, pool id) pair the recorded pool size and pool base offset are looked up and
/// the nodes of the pool are handed to AssignOutputInOneBufferPool.
/// @return INTERNAL_ERROR when size or base offset of a pool is missing, the error of
///         AssignOutputInOneBufferPool, otherwise SUCCESS
Status BufferPoolMemAssigner::AssignOutput() {
  for (auto &batch_entry : buffer_pool_nodes_) {
    const std::string &batch_label = batch_entry.first;
    for (auto &pool_entry : batch_entry.second) {
      const int64_t buffer_pool_id = pool_entry.first;

      auto &sizes_of_batch = buffer_pool_size_[batch_label];
      const auto size_iter = sizes_of_batch.find(buffer_pool_id);
      if (size_iter == sizes_of_batch.end()) {
        GELOGE(INTERNAL_ERROR, "[Get][BufferPoolSize]Pool id:%ld.", buffer_pool_id);
        REPORT_INNER_ERROR("E19999", "Failed to get buffer pool size, pool id:%ld.", buffer_pool_id);
        return INTERNAL_ERROR;
      }

      auto &offsets_of_batch = buffer_pool_offset_base_[batch_label];
      const auto offset_iter = offsets_of_batch.find(buffer_pool_id);
      if (offset_iter == offsets_of_batch.end()) {
        GELOGE(INTERNAL_ERROR, "[Get][BufferPoolBaseOffset]Pool id:%ld.", buffer_pool_id);
        REPORT_INNER_ERROR("E19999", "Failed to get buffer pool base offset, pool id:%ld.", buffer_pool_id);
        return INTERNAL_ERROR;
      }

      // The pool size is only used for logging here; the layout was fixed while initializing.
      const int64_t buffer_pool_size = size_iter->second;
      const int64_t output_offset_base = offset_iter->second;
      const Status ret = AssignOutputInOneBufferPool(batch_label, output_offset_base, pool_entry.second);
      if (ret != SUCCESS) {
        GELOGE(ret, "[Assign][OneBufferPool]Batch label:%s, pool id:%ld, pool size:%ld, offset base:%ld.",
               batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base);
        REPORT_INNER_ERROR("E19999", "Failed to assign output memory, batch label:%s, "
                           "pool id:%ld, pool size:%ld, offset base:%ld.",
                           batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base);
        return ret;
      }
      GELOGI("[Assign][Output]Assign output successfully, batch label:%s, pool id:%ld, pool size:%ld, offset base:%ld.",
             batch_label.c_str(), buffer_pool_id, buffer_pool_size, output_offset_base);
    }
  }
  return SUCCESS;
}
/// @brief Set the output offset of every node that belongs to one buffer pool.
/// Each node must carry ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET with exactly two entries
/// {size, logical offset}. The recorded size has to match the size returned by GetMemorySize, and the
/// final offset is output_offset_base plus the logical offset.
/// @param [in] batch_label batch label of the pool, only used for logging
/// @param [in] output_offset_base base offset of the pool
/// @param [in] buffer_pool_nodes nodes assigned to the pool
/// @return SUCCESS, the error of GetMemorySize, PARAM_INVALID for a missing/inconsistent attribute,
///         or the error raised by the overflow check on the final offset
Status BufferPoolMemAssigner::AssignOutputInOneBufferPool(const std::string &batch_label,
                                                          int64_t output_offset_base,
                                                          const std::vector<NodePtr> &buffer_pool_nodes) {
  for (const NodePtr &node : buffer_pool_nodes) {
    int64_t output_size = 0;
    Status ret = GetMemorySize(node, output_size);
    if (ret != SUCCESS) {
      GELOGE(ret, "[Get][MemSize]Node:%s.", node->GetName().c_str());
      REPORT_INNER_ERROR("E19999", "Failed to get output size, node:%s.", node->GetName().c_str());
      return ret;
    }
    OpDescPtr op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    std::vector<int64_t> memory_size_and_offset;
    bool get_attr = AttrUtils::GetListInt(op_desc, ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET, memory_size_and_offset);
    if (!get_attr || memory_size_and_offset.size() != kBufferPoolNodeMemInfoLength) {
      GELOGE(PARAM_INVALID, "[Get][Attr]Node:%s, mem info size:%zu, required size:%zu.",
             node->GetName().c_str(), memory_size_and_offset.size(), kBufferPoolNodeMemInfoLength);
      REPORT_INNER_ERROR("E19999", "Failed to get pool node memory info, node:%s, info size:%zu, required size:%zu.",
                         node->GetName().c_str(), memory_size_and_offset.size(), kBufferPoolNodeMemInfoLength);
      return PARAM_INVALID;
    }
    // The size stored on the node must still agree with the size computed now.
    if (output_size != memory_size_and_offset[kBufferPoolNodeOutputSizeIndex]) {
      GELOGE(PARAM_INVALID, "[Check][MemSize]Something wrong with memory size, pre size:%ld, curr size:%ld, node:%s.",
             memory_size_and_offset[kBufferPoolNodeOutputSizeIndex], output_size, node->GetName().c_str());
      REPORT_INNER_ERROR("E19999", "Something wrong with memory size, pre size:%ld, curr size:%ld, node:%s.",
                         memory_size_and_offset[kBufferPoolNodeOutputSizeIndex], output_size, node->GetName().c_str());
      return PARAM_INVALID;
    }
    int64_t logical_offset = memory_size_and_offset[kBufferPoolNodeOutputOffsetIndex];
    // Guard the signed addition like the other offset computations of this assigner do.
    FMK_INT64_ADDCHECK(output_offset_base, logical_offset);
    std::vector<int64_t> output_list = {(output_offset_base + logical_offset)};
    op_desc->SetOutputOffset(output_list);
    // log for IMAS tools
    GELOGI("[IMAS]Set %s name[%s] optype[%s] %s[%u] offset to [%ld] streamid[%ld] memtype[%ld] "
           "size[%zu] realsize[%zu] noalignsize[%zu] life time begin[%d] life time end[%d] "
           "child[%d:%d:%d:%d:%d] isref[%d] batch[%s]",
           compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
           "output", kBufferPoolNodeOutIndex, output_list[kBufferPoolNodeOutIndex], op_desc->GetStreamId(), mem_type_,
           static_cast<size_t>(output_size), static_cast<size_t>(output_size), static_cast<size_t>(output_size),
           0, 0, 0, 0, 0, 0, 0, 0, batch_label.c_str());
  }
  return SUCCESS;
}
} // namespace ge |
@@ -0,0 +1,83 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_ | |||
#define GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_ | |||
#include <vector> | |||
#include <map> | |||
#include <unordered_map> | |||
#include "graph/build/memory/mem_assigner.h" | |||
#include "runtime/mem.h" | |||
namespace ge { | |||
class BufferPoolMemAssigner : public MemAssigner { | |||
public: | |||
BufferPoolMemAssigner(ComputeGraphPtr compute_graph, const std::map<int64_t, size_t> &mem_type_to_offset) | |||
: MemAssigner(), compute_graph_(compute_graph), | |||
mem_type_(0), | |||
mem_offset_(0), | |||
mem_offset_base_(0), | |||
init_offset_base_(false), | |||
mem_type_to_offset_(mem_type_to_offset) {} | |||
BufferPoolMemAssigner(const BufferPoolMemAssigner &) = delete; | |||
BufferPoolMemAssigner &operator=(const BufferPoolMemAssigner &) = delete; | |||
~BufferPoolMemAssigner() override = default; | |||
Status Assign() override; | |||
size_t GetMemOffset() const { return mem_offset_; } | |||
int64_t GetMemType() const { return mem_type_; } | |||
private: | |||
static Status GetOutputMemoryType(const NodePtr &node, size_t idx, int64_t &memory_type); | |||
Status InitAssigner(const ComputeGraphPtr &graph); | |||
Status InitMemOffsetBase(const NodePtr &node); | |||
Status AssignOutput(); | |||
Status AssignOutputInOneBufferPool(const std::string &batch_label, | |||
int64_t output_offset_base, | |||
const std::vector<NodePtr> &buffer_pool_nodes); | |||
ComputeGraphPtr compute_graph_; | |||
int64_t mem_type_; | |||
size_t mem_offset_; | |||
int64_t mem_offset_base_; | |||
bool init_offset_base_; | |||
std::map<int64_t, size_t> mem_type_to_offset_; | |||
// Use map to ensure that each visit is in the order of pool id | |||
std::unordered_map<std::string, std::map<int64_t, std::vector<NodePtr>>> buffer_pool_nodes_; | |||
// Use map to ensure that each visit is in the order of pool id | |||
std::unordered_map<std::string, std::map<int64_t, int64_t>> buffer_pool_size_; | |||
std::unordered_map<std::string, std::unordered_map<int64_t, int64_t>> buffer_pool_offset_base_; | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_BUILD_MEMORY_BUFFER_POOL_MEM_ASSIGNER_H_ |
@@ -30,6 +30,7 @@ | |||
#include "graph/manager/graph_var_manager.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "graph/build/memory/buffer_pool_mem_assigner.h" | |||
namespace { | |||
const int kAllInputAddrIsAtomic = -1; | |||
@@ -231,6 +232,7 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, map<int64_t, size | |||
GE_CHK_STATUS_RET(ReAssignContinuousMemory(is_loop_graph), "ReAssignContinuousMemory Failed!"); | |||
GE_CHK_STATUS_RET(ReAssignAtomicMemory(is_loop_graph), "ReAssignAtomicMemory Failed!"); | |||
GE_CHK_STATUS_RET(AssignBufferPoolMemory(), "AssignBufferPoolMemory Failed!"); | |||
size_t total_mem_offset = 0; | |||
for (auto pair : memory_offset_) { | |||
@@ -1735,4 +1737,54 @@ ge::Status GraphMemoryAssigner::AssignContinuousInputMemoryWithAtomicProcess(con | |||
return ge::SUCCESS; | |||
} | |||
/// @brief Assign buffer pool memory for compute_graph_ and publish the resulting offset.
/// Nothing is done when the root graph is flagged as unknown shape, or when no node carries both
/// ATTR_NAME_BUFFER_POOL_ID and ATTR_NAME_BUFFER_POOL_SIZE. Otherwise a BufferPoolMemAssigner is run
/// with a snapshot of the current offsets, and the offset of the memory type it used is overwritten
/// with the assigner's final offset.
/// @return SUCCESS, the error of the assigner, or FAILED when its memory type is unknown here
Status GraphMemoryAssigner::AssignBufferPoolMemory() {
  auto root_graph = GraphUtils::FindRootGraph(compute_graph_);
  GE_CHECK_NOTNULL(root_graph);
  if (root_graph->GetGraphUnknownFlag()) {
    GELOGI("[Check][Enable]Unknown root graph does not support buffer pool memory, graph:%s.",
           compute_graph_->GetName().c_str());
    return SUCCESS;
  }

  // Buffer pool memory is considered enabled as soon as one node has both pool attributes.
  bool pool_mem_enabled = false;
  for (const NodePtr &node : compute_graph_->GetAllNodes()) {
    const auto op_desc = node->GetOpDesc();
    if (op_desc == nullptr) {
      continue;
    }
    if (op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) {
      pool_mem_enabled = true;
      break;
    }
  }
  if (!pool_mem_enabled) {
    GELOGD("[Check][Enable]Buffer pool memory is not enable, graph:%s.", compute_graph_->GetName().c_str());
    return SUCCESS;
  }

  // Snapshot of the current offset of every memory type for the assigner to start from.
  std::map<int64_t, size_t> mem_type_to_offset;
  for (const auto &type_and_offset : memory_offset_) {
    mem_type_to_offset[type_and_offset.first] = type_and_offset.second.mem_offset_;
  }

  BufferPoolMemAssigner buffer_pool_mem_assigner(compute_graph_, mem_type_to_offset);
  const Status status = buffer_pool_mem_assigner.Assign();
  if (status != SUCCESS) {
    GELOGE(status, "[Assign][BufferPoolMem]Graph:%s.", compute_graph_->GetName().c_str());
    REPORT_INNER_ERROR("E19999", "Failed to assign buffer pool memory, graph:%s.", compute_graph_->GetName().c_str());
    return status;
  }

  const int64_t mem_type = buffer_pool_mem_assigner.GetMemType();
  auto offset_iter = memory_offset_.find(mem_type);
  if (offset_iter == memory_offset_.end()) {
    GELOGE(FAILED, "[Check][MemType]Memory type is not supported, graph:%s, mem type:%ld.",
           compute_graph_->GetName().c_str(), mem_type);
    REPORT_INNER_ERROR("E19999", "Memory type is not supported, graph:%s, mem type:%ld.",
                       compute_graph_->GetName().c_str(), mem_type);
    return FAILED;
  }

  offset_iter->second.mem_offset_ = buffer_pool_mem_assigner.GetMemOffset();
  GELOGI("[Assign][BufferPoolMem]Assign buffer pool memory successfully, graph:%s, mem type:%ld, mem offset:%zu.",
         compute_graph_->GetName().c_str(), mem_type, buffer_pool_mem_assigner.GetMemOffset());
  return SUCCESS;
}
} // namespace ge |
@@ -188,6 +188,8 @@ class GraphMemoryAssigner { | |||
void PrintMemoryOffset(); | |||
Status AssignBufferPoolMemory(); | |||
MemoryOffsetMap memory_offset_; | |||
ge::ComputeGraphPtr compute_graph_; | |||
HybridMemAssignerPtr mem_assigner_; | |||
@@ -8,6 +8,7 @@ local_lib_src_files := memory_assigner.cc \ | |||
hybrid_mem_assigner.cc \ | |||
max_block_mem_assigner.cc \ | |||
var_mem_assign_util.cc \ | |||
buffer_pool_mem_assigner.cc \ | |||
local_lib_inc_path := ${LOCAL_PATH} \ | |||
${TOPDIR}inc \ | |||
@@ -18,6 +18,7 @@ | |||
#include "common/util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/common/omg_util.h" | |||
namespace ge { | |||
RunContextUtil::~RunContextUtil() { DestroyRtModelResources(); } | |||
@@ -88,9 +89,11 @@ Status RunContextUtil::CreateRtModelResources(uint32_t stream_num, uint32_t even | |||
} | |||
// Create rt event | |||
uint32_t create_flag = static_cast<uint32_t>((event_num > kEventReuseThreshold) ? RT_EVENT_WITH_FLAG : | |||
RT_EVENT_DEFAULT); | |||
for (uint32_t i = 0; i < event_num; ++i) { | |||
rtEvent_t event = nullptr; | |||
rt_ret = rtEventCreate(&event); | |||
rt_ret = rtEventCreateWithFlag(&event, create_flag); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "call rtEventCreate fail, ret:%d, index:%u, when %s", | |||
static_cast<int>(rt_ret), i, __FUNCTION__); | |||
@@ -27,6 +27,8 @@ | |||
#include "graph/ge_context.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "init/gelib.h" | |||
#include "common/string_util.h" | |||
#include "common/util/error_manager/error_manager.h" | |||
using std::map; | |||
using std::set; | |||
@@ -38,6 +40,13 @@ const int64_t kTaskNumPerNormalNode = 3; | |||
const int64_t kTaskNumPerHcclNode = 245; | |||
const char *const kTrueStr = "true"; | |||
const char *const kFalseStr = "false"; | |||
// One raw event multiplexing entry is split by kDelim into exactly this many fields.
constexpr size_t kEventMultiplexingItemCount = 3U;
// Field 0: key word, either kSend or kRecv.
constexpr size_t kKeyWordIndex = 0U;
// Field 1: name of the peer node.
constexpr size_t kNodeNameIndex = 1U;
// Field 2: event id, parsed as a non-negative integer.
constexpr size_t kEventIdIndex = 2U;
// Key words accepted in field 0.
const char *const kSend = "SendTo";
const char *const kRecv = "RecvFrom";
// Separator between the fields of one entry.
constexpr char kDelim = ';';
inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string &continuous_stream_label) { | |||
if (ge::AttrUtils::GetStr(op_desc, ge::ATTR_NAME_CONTINUOUS_STREAM_LABEL, continuous_stream_label)) { | |||
@@ -52,6 +61,97 @@ bool IsHcclOp(const string &op_type) { | |||
ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); | |||
return hccl_op_types.find(op_type) != hccl_op_types.end(); | |||
} | |||
/// @brief Parse the raw event multiplexing entries of one node.
/// Every entry is split by kDelim into three fields: key word, peer node name and event id.
/// Entries with key word kSend are appended to node_to_send[node], entries with key word kRecv to
/// node_to_recv[node], each as a (peer node name, event id) pair.
/// @param [in] node node that owns the entries, must not be null
/// @param [in] raw_event_multiplexing raw entries of the node
/// @param [out] node_to_send collected send entries
/// @param [out] node_to_recv collected receive entries
/// @return PARAM_INVALID for a wrong field count, an unparsable or negative event id, or an unknown
///         key word; otherwise SUCCESS
ge::Status ParseNodeEventMultiplexing(const ge::NodePtr &node,
    const std::vector<std::string> &raw_event_multiplexing,
    std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_send,
    std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_recv) {
  GE_CHECK_NOTNULL(node);
  for (const auto &str : raw_event_multiplexing) {
    std::vector<std::string> ele = ge::StringUtils::Split(str, kDelim);
    if (ele.size() != kEventMultiplexingItemCount) {
      GELOGE(ge::PARAM_INVALID, "[Check][RawMultiplexing]Size error, node:%s, require size:%zu, actually:%zu.",
             node->GetName().c_str(), kEventMultiplexingItemCount, ele.size());
      REPORT_INNER_ERROR("E19999", "Raw event multiplexing is invalid, node:%s, require size:%zu, actually:%zu.",
                         node->GetName().c_str(), kEventMultiplexingItemCount, ele.size());
      return ge::PARAM_INVALID;
    }
    int value = 0;
    try {
      value = std::stoi(ele[kEventIdIndex]);
    } catch (const std::invalid_argument &) {
      GELOGE(ge::PARAM_INVALID, "[Throw][Exception]Event id is invalid, node:%s, raw:%s.",
             node->GetName().c_str(), ele[kEventIdIndex].c_str());
      REPORT_INNER_ERROR("E19999", "Event id is invalid, node:%s, raw:%s.",
                         node->GetName().c_str(), ele[kEventIdIndex].c_str());
      return ge::PARAM_INVALID;
    } catch (const std::out_of_range &) {
      GELOGE(ge::PARAM_INVALID, "[Throw][Exception]Event id is out of range, node:%s, raw:%s.",
             node->GetName().c_str(), ele[kEventIdIndex].c_str());
      REPORT_INNER_ERROR("E19999", "Event id is out of range, node:%s, raw:%s.",
                         node->GetName().c_str(), ele[kEventIdIndex].c_str());
      return ge::PARAM_INVALID;
    }
    // Event ids are stored as uint32_t, so negative values are rejected before the cast.
    if (value < 0) {
      GELOGE(ge::PARAM_INVALID, "[Check][EventId]Event id is out of range, node:%s, raw:%s, value:%d.",
             node->GetName().c_str(), ele[kEventIdIndex].c_str(), value);
      REPORT_INNER_ERROR("E19999", "Event id is out of range, node:%s, raw:%s, value:%d.",
                         node->GetName().c_str(), ele[kEventIdIndex].c_str(), value);
      return ge::PARAM_INVALID;
    }
    if (ele[kKeyWordIndex] == kSend) {
      node_to_send[node].emplace_back(ele[kNodeNameIndex], static_cast<uint32_t>(value));
    } else if (ele[kKeyWordIndex] == kRecv) {
      node_to_recv[node].emplace_back(ele[kNodeNameIndex], static_cast<uint32_t>(value));
    } else {
      // Report the offending key word itself (field kKeyWordIndex), not the event id field.
      GELOGE(ge::PARAM_INVALID, "[Check][KeyWord]Key word is not supported, node:%s, key:%s.",
             node->GetName().c_str(), ele[kKeyWordIndex].c_str());
      REPORT_INNER_ERROR("E19999", "Key word is not supported, node:%s, key:%s.",
                         node->GetName().c_str(), ele[kKeyWordIndex].c_str());
      return ge::PARAM_INVALID;
    }
  }
  return ge::SUCCESS;
}
/// @brief Parse the event multiplexing attribute of every node returned by graph->GetNodes(...).
/// Each visited node is registered in name_to_node_map under its name (an existing entry with the
/// same name is kept). Nodes carrying ATTR_NAME_EVENT_MULTIPLEXING have their entries parsed into
/// node_to_send / node_to_recv.
/// @param [in] graph graph to scan
/// @param [out] name_to_node_map node name to node
/// @param [out] node_to_send send entries per node
/// @param [out] node_to_recv receive entries per node
/// @return PARAM_INVALID when the attribute exists but cannot be read, the error of
///         ParseNodeEventMultiplexing, otherwise SUCCESS
ge::Status ParseAllNodeEventMultiplexing(const ge::ComputeGraphPtr &graph,
    std::unordered_map<std::string, ge::NodePtr> &name_to_node_map,
    std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_send,
    std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_recv) {
  for (const auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
    ge::OpDescPtr op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    // Every node is recorded, with or without the attribute, so peers can be found by name later.
    (void) name_to_node_map.emplace(node->GetName(), node);

    const bool has_multiplexing = op_desc->HasAttr(ge::ATTR_NAME_EVENT_MULTIPLEXING);
    if (has_multiplexing) {
      std::vector<std::string> raw_entries;
      if (!ge::AttrUtils::GetListStr(op_desc, ge::ATTR_NAME_EVENT_MULTIPLEXING, raw_entries)) {
        GELOGE(ge::PARAM_INVALID, "[Get][Attr]Node:%s.", node->GetName().c_str());
        REPORT_INNER_ERROR("E19999", "Failed to get raw event multiplexing, node:%s.", node->GetName().c_str());
        return ge::PARAM_INVALID;
      }

      const auto status = ParseNodeEventMultiplexing(node, raw_entries, node_to_send, node_to_recv);
      if (status != ge::SUCCESS) {
        GELOGE(status, "[Parse][Eventmultiplexing]Node:%s.", node->GetName().c_str());
        REPORT_INNER_ERROR("E19999", "Failed to parse node event multiplexing, node:%s.", node->GetName().c_str());
        return status;
      }
    }
  }
  return ge::SUCCESS;
}
/// @brief Collect the elements of b that also occur in a.
/// The result follows the order of b, and an element repeated in b is repeated in the result.
/// Neither input is modified, so both are taken by const reference.
/// @param [in] a values to test membership against
/// @param [in] b values to filter
/// @return elements of b contained in a
std::vector<uint32_t> GetIntersection(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b) {
  const std::unordered_set<uint32_t> ele_of_a(a.begin(), a.end());
  std::vector<uint32_t> res;
  for (const uint32_t ele : b) {
    if (ele_of_a.count(ele) > 0) {
      res.emplace_back(ele);
    }
  }
  return res;
}
} // namespace | |||
namespace ge { | |||
@@ -150,6 +250,12 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu | |||
return status; | |||
} | |||
status = RefreshEventsWithReuse(); | |||
if (status != SUCCESS) { | |||
GELOGE(status, "[Refresh][Events]RefreshEventsWithReuse failed!"); | |||
return status; | |||
} | |||
status = InsertSyncEventNodes(); | |||
if (status != SUCCESS) { | |||
GELOGE(status, "InsertSyncEventNode failed!"); | |||
@@ -1161,6 +1267,94 @@ Status StreamAllocator::CheckStreamActived() const { | |||
return SUCCESS; | |||
} | |||
/// @brief Replace the event shared by a send/recv node pair with a multiplexed event id.
/// For every (node, peer name, event id) entry the peer is resolved through name_to_node_map. The
/// event currently shared by the pair is the intersection of the send node's send events and the
/// recv node's recv events; its first element is replaced in both lists by (event id + event_num_).
/// Pairs without a shared event are skipped; more than one shared event only produces a warning.
/// @param [in] send_to true: keys of node_to_event_id are senders and peers are receivers;
///             false: keys are receivers and peers are senders
/// @param [in] name_to_node_map node name to node
/// @param [in] node_to_event_id per node list of (peer node name, event id)
/// @return PARAM_INVALID when a peer name is unknown, otherwise SUCCESS
Status StreamAllocator::ReuseEvent(bool send_to,
    const std::unordered_map<std::string, ge::NodePtr> &name_to_node_map,
    const std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_event_id) {
  // Drops every occurrence of stale_id from the list and appends fresh_id.
  const auto replace_event = [](std::vector<uint32_t> &events, uint32_t stale_id, uint32_t fresh_id) {
    events.erase(std::remove(events.begin(), events.end(), stale_id), events.end());
    events.push_back(fresh_id);
  };

  for (const auto &entry : node_to_event_id) {
    const ge::NodePtr &curr_node = entry.first;
    for (const auto &event_pair : entry.second) {
      const auto peer_iter = name_to_node_map.find(event_pair.first);
      if (peer_iter == name_to_node_map.end()) {
        GELOGE(PARAM_INVALID, "[Get][Node]Name:%s.", event_pair.first.c_str());
        REPORT_INNER_ERROR("E19999", "Failed to find node, name:%s.", event_pair.first.c_str());
        return PARAM_INVALID;
      }

      // The current node plays the role selected by send_to, the peer takes the other one.
      const NodePtr send_node = send_to ? curr_node : peer_iter->second;
      const NodePtr recv_node = send_to ? peer_iter->second : curr_node;
      GE_CHECK_NOTNULL(send_node);
      GE_CHECK_NOTNULL(recv_node);

      auto &send_events = node_to_send_events_[send_node];
      auto &recv_events = node_to_recv_events_[recv_node];
      const auto shared_events = GetIntersection(send_events, recv_events);
      if (shared_events.empty()) {
        GELOGI("[Check][Optimized]Send:%s, recv:%s.", send_node->GetName().c_str(), recv_node->GetName().c_str());
        continue;
      }
      if (shared_events.size() != 1) {
        GELOGW("[Check][Event]More than one event are found between %s and %s, event num:%zu.",
               send_node->GetName().c_str(), recv_node->GetName().c_str(), shared_events.size());
      }

      const uint32_t old_event = shared_events[0];
      // Multiplexed ids are placed after the ids counted by event_num_.
      const uint32_t new_event = event_pair.second + event_num_;
      replace_event(send_events, old_event, new_event);
      replace_event(recv_events, old_event, new_event);
      GELOGI("[Reuse][Event]Replace event successfully, send node:%s, recv node:%s, old id:%u, new id:%u.",
             send_node->GetName().c_str(), recv_node->GetName().c_str(), old_event, new_event);
    }
  }
  return ge::SUCCESS;
}
// Rewrite event ids so that events can be reused, then renumber them continuously.
// Does nothing unless the current event count exceeds kEventReuseThreshold.
Status StreamAllocator::RefreshEventsWithReuse() {
  GELOGI("[Refresh][Events]Refresh events with reuse, stream num:%ld, original event num:%u.", stream_num_, event_num_);
  const bool below_threshold = (event_num_ <= kEventReuseThreshold);
  if (below_threshold) {
    GELOGI("[Check][ReuseThreshold]Event used num is %u, less than %u, skip reuse.",
           event_num_, kEventReuseThreshold);
    return SUCCESS;
  }

  // Gather the multiplexing description of every node in the whole graph.
  std::unordered_map<std::string, NodePtr> nodes_by_name;
  std::unordered_map<NodePtr, std::vector<std::pair<std::string, uint32_t>>> send_multiplexing;
  std::unordered_map<NodePtr, std::vector<std::pair<std::string, uint32_t>>> recv_multiplexing;
  const Status parse_ret = ParseAllNodeEventMultiplexing(whole_graph_, nodes_by_name, send_multiplexing,
                                                         recv_multiplexing);
  if (parse_ret != SUCCESS) {
    GELOGE(parse_ret, "[Parse][AllNodeEventMultiplexing]Graph:%s.", whole_graph_->GetName().c_str());
    REPORT_INNER_ERROR("E19999", "Failed to parse all node event multiplexing, graph:%s.",
                       whole_graph_->GetName().c_str());
    return parse_ret;
  }
  if (send_multiplexing.empty() && recv_multiplexing.empty()) {
    return SUCCESS;
  }

  // Send side first, then recv side.
  const Status send_ret = ReuseEvent(true, nodes_by_name, send_multiplexing);
  if (send_ret != SUCCESS) {
    GELOGE(send_ret, "[Reuse][Event]Phase:Send, graph:%s.", whole_graph_->GetName().c_str());
    REPORT_INNER_ERROR("E19999", "Failed to reuse event, phase:Send, graph:%s.", whole_graph_->GetName().c_str());
    return send_ret;
  }
  const Status recv_ret = ReuseEvent(false, nodes_by_name, recv_multiplexing);
  if (recv_ret != SUCCESS) {
    GELOGE(recv_ret, "[Reuse][Event]Phase:Recv, graph:%s.", whole_graph_->GetName().c_str());
    REPORT_INNER_ERROR("E19999", "Failed to reuse event, phase:Recv, graph:%s.", whole_graph_->GetName().c_str());
    return recv_ret;
  }

  const Status refresh_ret = RefreshContinuousEvents();
  if (refresh_ret != SUCCESS) {
    GELOGE(refresh_ret, "[Refresh][ContinuousEvents]Graph:%s.", whole_graph_->GetName().c_str());
    REPORT_INNER_ERROR("E19999", "Failed to refresh continuous events, graph:%s.", whole_graph_->GetName().c_str());
    return refresh_ret;
  }
  GELOGI("[Refresh][Events]RefreshEventsWithReuse successfully, event num:%u.", event_num_);
  return SUCCESS;
}
// Refresh events to continuous events | |||
Status StreamAllocator::RefreshContinuousEvents() { | |||
// Establish a mapping relationship from old to new event id | |||
@@ -1168,8 +1362,10 @@ Status StreamAllocator::RefreshContinuousEvents() { | |||
uint32_t new_event_id = 0; | |||
for (const auto &one_pair : node_to_send_events_) { | |||
for (const auto &event_id : one_pair.second) { | |||
old_to_new_events[event_id] = new_event_id; | |||
new_event_id++; | |||
if (old_to_new_events.find(event_id) == old_to_new_events.end()) { | |||
old_to_new_events[event_id] = new_event_id; | |||
new_event_id++; | |||
} | |||
} | |||
} | |||
@@ -1208,6 +1404,7 @@ Status StreamAllocator::RefreshContinuousEvents() { | |||
// Insert the real send/recv node in the graph | |||
Status StreamAllocator::InsertSyncEventNodes() { | |||
unordered_map<string, uint32_t> sync_event_name; | |||
for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { | |||
// Add the node corresponding to the recv event | |||
vector<uint32_t> recv_event_id_list; | |||
@@ -1217,6 +1414,13 @@ Status StreamAllocator::InsertSyncEventNodes() { | |||
GE_CHECK_NOTNULL(node->GetOutControlAnchor()); | |||
for (auto &event_id : recv_event_id_list) { | |||
string recv_node_name = whole_graph_->GetName() + "_Recv_" + to_string(event_id); | |||
auto iter = sync_event_name.find(recv_node_name); | |||
if (iter == sync_event_name.end()) { | |||
sync_event_name[recv_node_name] = 1; | |||
} else { | |||
recv_node_name = recv_node_name + "_Reuse_" + to_string(iter->second); | |||
++(iter->second); | |||
} | |||
OpDescPtr op_desc_ptr = MakeShared<OpDesc>(recv_node_name, RECV); | |||
GE_CHECK_NOTNULL(op_desc_ptr); | |||
@@ -1251,6 +1455,13 @@ Status StreamAllocator::InsertSyncEventNodes() { | |||
for (auto &event_id : send_event_id_list) { | |||
string send_node_name = whole_graph_->GetName() + "_Send_" + to_string(event_id); | |||
auto iter = sync_event_name.find(send_node_name); | |||
if (iter == sync_event_name.end()) { | |||
sync_event_name[send_node_name] = 1; | |||
} else { | |||
send_node_name = send_node_name + "_Reuse_" + to_string(iter->second); | |||
++(iter->second); | |||
} | |||
OpDescPtr op_desc_ptr = MakeShared<OpDesc>(send_node_name, SEND); | |||
GE_CHECK_NOTNULL(op_desc_ptr); | |||
@@ -1300,12 +1511,16 @@ void StreamAllocator::DumpEvents() { | |||
GELOGD("After RefreshRealStream: stream %ld.", stream_id); | |||
for (const auto &node : one_pair.second) { | |||
if (node == nullptr || node->GetOpDesc() == nullptr) { | |||
continue; | |||
} | |||
string send_event_str; | |||
for (const auto &send_event_id : node_to_send_events_[node]) { | |||
send_event_str += " " + to_string(send_event_id); | |||
} | |||
if (!send_event_str.empty()) { | |||
GELOGI("node: %s, send events: %s", node->GetName().c_str(), send_event_str.c_str()); | |||
GELOGI("node: %s, id: %ld, stream id :%ld, send events: %s.", node->GetName().c_str(), | |||
node->GetOpDesc()->GetId(), node->GetOpDesc()->GetStreamId(), send_event_str.c_str()); | |||
} | |||
string recv_event_str; | |||
@@ -1313,7 +1528,8 @@ void StreamAllocator::DumpEvents() { | |||
recv_event_str += " " + to_string(recv_event_id); | |||
} | |||
if (!recv_event_str.empty()) { | |||
GELOGI("node: %s, recv events: %s", node->GetName().c_str(), recv_event_str.c_str()); | |||
GELOGI("node: %s, id: %ld, stream id :%ld, recv events: %s.", node->GetName().c_str(), | |||
node->GetOpDesc()->GetId(), node->GetOpDesc()->GetStreamId(), recv_event_str.c_str()); | |||
} | |||
} | |||
} | |||
@@ -71,6 +71,10 @@ class StreamAllocator { | |||
Status SetActiveStreamsForLoop(); | |||
Status CheckStreamActived() const; | |||
Status ReuseEvent(bool send_to, | |||
const std::unordered_map<std::string, ge::NodePtr> &name_to_node_map, | |||
const std::unordered_map<ge::NodePtr, std::vector<std::pair<std::string, uint32_t>>> &node_to_event_id); | |||
Status RefreshEventsWithReuse(); | |||
Status RefreshContinuousEvents(); | |||
Status InsertSyncEventNodes(); | |||
@@ -21,6 +21,8 @@ | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "common/math/math_util.h" | |||
namespace ge { | |||
/// | |||
@@ -204,4 +206,42 @@ Status SetNextIteration(const ge::NodePtr &node, const std::string &next) { | |||
return SUCCESS; | |||
} | |||
///
/// @brief Align the memory size up to the next multiple of the alignment
/// @param [in/out] memory size; left untouched when it is not positive, when the alignment is not
///                 positive, or when the aligned value would not fit into int64_t
/// @param [in] alignment
/// @return void
///
void AlignMemSize(int64_t &mem_size, int64_t align_size) {
  // A non-positive alignment would divide by zero or round in the wrong direction.
  if (mem_size <= 0 || align_size <= 0) {
    return;
  }
  const int64_t remainder = mem_size % align_size;
  if (remainder == 0) {
    return;
  }
  // Work with the padding instead of (mem_size + align_size - 1) so that no intermediate value overflows.
  const int64_t padding = align_size - remainder;
  if (mem_size > INT64_MAX - padding) {
    return;
  }
  mem_size += padding;
}
///
/// @brief Get memory size from tensor desc
///        The size of output kBufferPoolNodeOutIndex is aligned up to kBufferPoolMemAlignSize and
///        one extra kBufferPoolMemAlignSize is reserved both before and after it.
/// @param [in] node
/// @param [out] memory size
/// @return Status
///
Status GetMemorySize(const NodePtr &node, int64_t &output_size) {
  // The node itself is dereferenced below, so reject a null node up front.
  GE_CHECK_NOTNULL(node);
  const auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  auto output_op_desc = op_desc->GetOutputDescPtr(kBufferPoolNodeOutIndex);
  GE_CHECK_NOTNULL(output_op_desc);
  int64_t size = 0;
  auto ret = ge::TensorUtils::GetSize(*output_op_desc, size);
  if (ret != ge::GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "[Get][Size]Node:%s.", node->GetName().c_str());
    REPORT_INNER_ERROR("E19999", "Failed to get output size, node:%s.", node->GetName().c_str());
    return INTERNAL_ERROR;
  }
  // Aligning adds less than one alignment unit, so checking a full unit is sufficient.
  FMK_INT64_ADDCHECK(size, kBufferPoolMemAlignSize);
  AlignMemSize(size, kBufferPoolMemAlignSize);
  // The HCOM operator requires an additional 512 bytes before and after
  FMK_INT64_ADDCHECK(size, (kBufferPoolMemAlignSize + kBufferPoolMemAlignSize));
  output_size = kBufferPoolMemAlignSize + size + kBufferPoolMemAlignSize;
  return SUCCESS;
}
} // namespace ge |
@@ -27,6 +27,11 @@ | |||
#include "graph/node.h" | |||
namespace ge { | |||
// Header-level constants. Namespace-scope constexpr objects already have internal linkage,
// so no unnamed namespace is needed (and unnamed namespaces should not appear in headers).

// Alignment (bytes) applied to buffer pool memory sizes; also the padding reserved before and after a block.
constexpr int64_t kBufferPoolMemAlignSize = 512;
// Index of the output whose tensor desc determines the memory size of a buffer pool node.
constexpr uint32_t kBufferPoolNodeOutIndex = 0;
// Event reuse is only attempted when the number of events exceeds this value.
constexpr uint32_t kEventReuseThreshold = 65500;
/// | |||
/// @brief get the Original Type of FrameworkOp | |||
/// @param [in] node | |||
@@ -96,6 +101,22 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node); | |||
/// @return Status | |||
/// | |||
Status SetNextIteration(const ge::NodePtr &node, const std::string &next); | |||
/// | |||
/// @brief Align the memory | |||
/// @param [in/out] memory size | |||
/// @param [in] alignment | |||
/// @return void | |||
/// | |||
void AlignMemSize(int64_t &mem_size, int64_t align_size); | |||
/// | |||
/// @brief Get memory size from tensor desc | |||
/// @param [in] node | |||
/// @param [out] memory size | |||
/// @return Status | |||
/// | |||
Status GetMemorySize(const NodePtr &node, int64_t &output_size); | |||
} // namespace ge | |||
#endif // GE_GRAPH_COMMON_OMG_UTIL_H_ |
@@ -60,6 +60,7 @@ | |||
#include "securec.h" | |||
#include "graph/common/local_context.h" | |||
#include "common/formats/utils/formats_trans_utils.h" | |||
#include "graph/common/omg_util.h" | |||
// create std::thread, catch exceptions using try/catch | |||
#define CREATE_STD_THREAD(thread_id, func, args) \ | |||
@@ -664,9 +665,12 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||
GELOGI("Logical stream index:%u, stream:%p, rtstream: %d.", i, stream, rt_stream_id); | |||
} | |||
for (uint32_t i = 0; i < EventNum(); i++) { | |||
rtEvent_t rt_event; | |||
GE_CHK_RT_RET(rtEventCreate(&rt_event)); | |||
uint32_t event_num = EventNum(); | |||
uint32_t create_flag = static_cast<uint32_t>((event_num > kEventReuseThreshold) ? RT_EVENT_WITH_FLAG : | |||
RT_EVENT_DEFAULT); | |||
for (uint32_t i = 0; i < event_num; ++i) { | |||
rtEvent_t rt_event = nullptr; | |||
GE_CHK_RT_RET(rtEventCreateWithFlag(&rt_event, create_flag)); | |||
event_list_.push_back(rt_event); | |||
} | |||
@@ -95,6 +95,7 @@ | |||
#include "graph/passes/memcpy_addr_async_pass.h" | |||
#include "graph/passes/hccl_continuous_memcpy_pass.h" | |||
#include "graph/passes/parallel_group_pass.h" | |||
#include "graph/passes/buffer_pool_memory_pass.h" | |||
#include "graph/build/label_allocator.h" | |||
#include "graph/utils/tensor_adapter.h" | |||
#include "inc/pass_manager.h" | |||
@@ -2528,6 +2529,12 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { | |||
GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed."); | |||
GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); | |||
// Process offset and dependency for buffer pool memory assigner. | |||
GE_TIMESTAMP_START(BufferPoolMemoryPass); | |||
BufferPoolMemoryPass buffer_pool_mem_pass; | |||
GE_CHK_STATUS_RET(buffer_pool_mem_pass.Run(compute_graph), "Failed to process for buffer pool allocator."); | |||
GE_TIMESTAMP_END(BufferPoolMemoryPass, "BufferPoolMemoryPass::Run."); | |||
// Handle parallel group . | |||
GE_TIMESTAMP_START(ParallelGroup); | |||
ParallelGroupPass parallel_group_pass; | |||
@@ -0,0 +1,574 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/buffer_pool_memory_pass.h" | |||
#include <string> | |||
#include <vector> | |||
#include "graph/common/omg_util.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "common/math/math_util.h" | |||
namespace ge { | |||
namespace { | |||
const size_t kBufferPoolNodeInSize = 1; | |||
const size_t kBufferPoolNodeOutSize = 1; | |||
} // namespace | |||
/// @brief Pass entry: prepares offsets and dependencies of buffer pool nodes in a root, known-shape graph.
/// @param [in] graph graph to process
/// @return SUCCESS when the graph is skipped or processed; an error status otherwise
Status BufferPoolMemoryPass::Run(ComputeGraphPtr graph) {
  if (graph == nullptr) {
    GELOGE(PARAM_INVALID, "[Check][Graph]Graph is nullptr");
    REPORT_INNER_ERROR("E19999", "Input graph is nullptr");
    return PARAM_INVALID;
  }
  // Cache prefetching targets very large models: weight data is fetched in advance into a
  // dedicated memory pool. A dynamic shape model goes through the executor flow and gets no
  // static memory assignment, which is handled separately, so such graphs (and sub graphs)
  // are skipped here.
  const bool is_sub_graph = (graph->GetParentGraph() != nullptr);
  if (is_sub_graph || graph->GetGraphUnknownFlag()) {
    return SUCCESS;
  }
  if (!IsBufferPoolMemEnable(graph)) {
    GELOGD("[Check][Enable]Buffer pool memory is not enable, graph:%s.", graph->GetName().c_str());
    return SUCCESS;
  }

  Status sort_ret = graph->TopologicalSorting();
  if (sort_ret != SUCCESS) {
    GELOGE(sort_ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
    REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
    return sort_ret;
  }

  if (CopyOutForMultiUsedOutput(graph) != SUCCESS) {
    GELOGE(FAILED, "[Copy][Output]Graph:%s.", graph->GetName().c_str());
    return FAILED;
  }

  if (GetBufferPoolAndPeerCalcNodes(graph) != SUCCESS) {
    GELOGE(FAILED, "[Get][BufferPoolNode]Graph:%s.", graph->GetName().c_str());
    return FAILED;
  }
  if (calc_nodes_.empty()) {
    GELOGE(FAILED, "[Check][BufferPoolNode]Graph:%s.", graph->GetName().c_str());
    REPORT_CALL_ERROR("E19999", "All Buffer pool nodes are isolated nodes in graph:%s.", graph->GetName().c_str());
    return FAILED;
  }

  if (AllocateAllBufferPoolSpace() != SUCCESS) {
    GELOGE(FAILED, "[Alloc][BufferPoolMem]Graph:%s.", graph->GetName().c_str());
    return FAILED;
  }

  if (SetResultOfMemoryAndEvent() != SUCCESS) {
    GELOGE(FAILED, "[Set][Result]Graph:%s.", graph->GetName().c_str());
    return FAILED;
  }

  // Sort again after the steps above have run.
  sort_ret = graph->TopologicalSorting();
  if (sort_ret != SUCCESS) {
    GELOGE(sort_ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
    REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
    return sort_ret;
  }
  return SUCCESS;
}
/// @brief Remove every element from the queue.
/// @param [in/out] q queue to be emptied
void BufferPoolMemoryPass::ClearQueue(std::queue<std::pair<std::string, uint32_t>> &q) {
  // Swapping with a fresh queue leaves q empty; the old content dies with the temporary.
  std::queue<std::pair<std::string, uint32_t>> empty_queue;
  q.swap(empty_queue);
}
/// @brief Tell whether any node of the graph carries both buffer pool attributes.
/// @param [in] graph graph whose nodes (GetAllNodes) are inspected
/// @return true when a node has ATTR_NAME_BUFFER_POOL_ID and ATTR_NAME_BUFFER_POOL_SIZE, false otherwise
/// NOTE(review): the definition is typed Status although it yields boolean values and the caller
/// negates the result as a boolean - confirm against the declaration whether bool was intended.
Status BufferPoolMemoryPass::IsBufferPoolMemEnable(const ComputeGraphPtr &graph) {
  bool enabled = false;
  for (NodePtr &node : graph->GetAllNodes()) {
    auto op_desc = node->GetOpDesc();
    if ((op_desc != nullptr) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) &&
        op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) {
      enabled = true;
      break;
    }
  }
  return enabled;
}
/// @brief Accumulate the size required for a pool and verify it does not exceed the pool size.
/// @param [in] total_size size to add for this pool
/// @param [in] pool_id id of the buffer pool
/// @param [in] buffer_pool_size capacity of the buffer pool
/// @param [in/out] calc_total_size accumulated required size per pool id
/// @return SUCCESS, or INTERNAL_ERROR when the accumulated size is larger than the pool size
Status BufferPoolMemoryPass::CheckBufferPoolSize(int64_t total_size, int64_t pool_id, int64_t buffer_pool_size,
                                                 std::unordered_map<int64_t, int64_t> &calc_total_size) {
  auto size_iter = calc_total_size.find(pool_id);
  if (size_iter != calc_total_size.end()) {
    FMK_INT64_ADDCHECK(size_iter->second, total_size);
    size_iter->second += total_size;
  } else {
    // First request for this pool: start the accumulation with this size.
    size_iter = calc_total_size.emplace(pool_id, total_size).first;
  }

  const int64_t required_size = size_iter->second;
  if (required_size <= buffer_pool_size) {
    return SUCCESS;
  }
  GELOGE(INTERNAL_ERROR, "[Check][Size]The memory required at the same is greater than buffer pool size, "
         "pool id:%ld, pool size:%ld, required size:%ld.", pool_id, buffer_pool_size, required_size);
  REPORT_INNER_ERROR("E19999", "The memory required at the same is greater than buffer pool size, pool id:%ld,"
                     " pool size:%ld, required size:%ld.", pool_id, buffer_pool_size, required_size);
  return INTERNAL_ERROR;
}
/// @brief Make sure pre_node is ordered before curr_node, adding a control edge when that is possible.
/// @param [in] pre_node node that has to come first
/// @param [in] curr_node node that has to come after pre_node
/// @param [out] not_change false when a control edge was added, true otherwise
/// @return SUCCESS when the order is (or has been made) valid, an error status otherwise
Status BufferPoolMemoryPass::TryToFixNodeOrder(NodePtr &pre_node, NodePtr &curr_node, bool &not_change) {
  not_change = true;
  GE_CHECK_NOTNULL(pre_node);
  GE_CHECK_NOTNULL(curr_node);
  auto pre_node_graph = pre_node->GetOwnerComputeGraph();
  auto curr_node_graph = curr_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(pre_node_graph);
  GE_CHECK_NOTNULL(curr_node_graph);
  std::string pre_node_stream_label;
  (void) AttrUtils::GetStr(pre_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, pre_node_stream_label);
  std::string curr_node_stream_label;
  (void) AttrUtils::GetStr(curr_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, curr_node_stream_label);
  // Compare the labels of the two different nodes. The former check compared the label of
  // pre_node with itself, so it was always true and the stream label was effectively ignored.
  if ((pre_node_graph == curr_node_graph) && (pre_node_stream_label == curr_node_stream_label)) {
    // Same subgraph, including simultaneously in the root graph.
    auto ret = ge::GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), curr_node->GetInControlAnchor());
    if (ret != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR, "[Add][Edge]Src:%s, dst:%s.", pre_node->GetName().c_str(), curr_node->GetName().c_str());
      REPORT_CALL_ERROR("E19999", "Failed to add ctrl edge from %s to %s.",
                        pre_node->GetName().c_str(), curr_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
    not_change = false;
  } else if (pre_node_graph->GetParentGraph() == curr_node_graph->GetParentGraph() &&
             pre_node_graph->GetParentNode() != nullptr && curr_node_graph->GetParentNode() != nullptr) {
    // Two nodes are located on different child graphs of different parent nodes.
    auto pre_node_parent_op_desc = pre_node_graph->GetParentNode()->GetOpDesc();
    auto curr_node_parent_op_desc = curr_node_graph->GetParentNode()->GetOpDesc();
    GE_CHECK_NOTNULL(pre_node_parent_op_desc);
    GE_CHECK_NOTNULL(curr_node_parent_op_desc);
    // A correct dependency between the parent nodes already guarantees the dependency between
    // the child nodes, so no control edge has to be added.
    if (pre_node_parent_op_desc->GetId() > curr_node_parent_op_desc->GetId()) {
      GELOGE(INTERNAL_ERROR, "[Check][Dependency]Invalid dependency, pre node:%s, curr node:%s.",
             pre_node->GetName().c_str(), curr_node->GetName().c_str());
      REPORT_INNER_ERROR("E19999", "Invalid dependency, pre node:%s, curr node:%s.",
                         pre_node->GetName().c_str(), curr_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
    GELOGI("[Check][Dependency]The two nodes are located in sub graphs of different parent nodes and meet the "
           "dependency relationship. pre:%s, curr:%s.", pre_node->GetName().c_str(), curr_node->GetName().c_str());
  } else {
    GELOGE(INTERNAL_ERROR, "[Check][Dependency]Invalid dependency, pre node:%s, curr node:%s.",
           pre_node->GetName().c_str(), curr_node->GetName().c_str());
    REPORT_INNER_ERROR("E19999", "Invalid dependency, pre node:%s, curr node:%s.",
                       pre_node->GetName().c_str(), curr_node->GetName().c_str());
    return INTERNAL_ERROR;
  }
  return SUCCESS;
}
/// @brief Insert a MEMCPYASYNC node between output kBufferPoolNodeOutIndex of node and all its consumers.
/// @param [in] graph graph that receives the new node
/// @param [in] node node whose output is copied
/// @return SUCCESS, FAILED when the insertion fails, or the GE_CHECK_NOTNULL status on a null object
Status BufferPoolMemoryPass::InsertMemCpyNodeAfter(ComputeGraphPtr &graph, NodePtr &node) {
  GE_CHECK_NOTNULL(graph);
  GE_CHECK_NOTNULL(node);
  const auto node_op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(node_op_desc);
  auto out_anchor = node->GetOutDataAnchor(kBufferPoolNodeOutIndex);
  GE_CHECK_NOTNULL(out_anchor);

  // The copy node has one input and one output, both described like the copied output.
  const auto output_desc = node_op_desc->GetOutputDesc(kBufferPoolNodeOutIndex);
  OpDescBuilder op_desc_builder(node->GetName() + "_memcpy_async", MEMCPYASYNC);
  auto mem_copy_op = op_desc_builder.AddInput("x", output_desc)
                                    .AddOutput("y", output_desc)
                                    .Build();
  GE_CHECK_NOTNULL(mem_copy_op);

  // Carry a non-empty stream label of the source node over to the copy node.
  std::string stream_label;
  const bool has_stream_label = AttrUtils::GetStr(node_op_desc, ATTR_NAME_STREAM_LABEL, stream_label);
  if (has_stream_label && !stream_label.empty()) {
    (void) AttrUtils::SetStr(mem_copy_op, ATTR_NAME_STREAM_LABEL, stream_label);
  }

  auto peer_in_anchors = out_anchor->GetPeerInDataAnchors();
  std::vector<InDataAnchorPtr> in_anchors(peer_in_anchors.begin(), peer_in_anchors.end());
  auto mem_copy_node = graph->AddNode(mem_copy_op);
  GE_CHECK_NOTNULL(mem_copy_node);
  if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, mem_copy_node) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "[Insert][Node] Node:%s.", node->GetName().c_str());
    REPORT_CALL_ERROR("E19999", "Failed to insert mem copy node after %s.", node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
/// @brief Keep every buffer pool node single-input/single-output by copying outputs that are used several times.
/// @param [in] graph graph to process; re-sorted topologically when a copy node was inserted
/// @return SUCCESS, PARAM_INVALID for an unsupported node layout, INTERNAL_ERROR when inserting fails,
///         or the status of the topological sort
Status BufferPoolMemoryPass::CopyOutForMultiUsedOutput(ComputeGraphPtr &graph) {
  bool graph_modified = false;
  for (NodePtr &node : graph->GetAllNodes()) {
    auto op_desc = node->GetOpDesc();
    if (op_desc == nullptr) {
      continue;
    }
    // Only nodes carrying both buffer pool attributes are of interest.
    if (!op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) || !op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) {
      continue;
    }
    const bool single_in_out_nodes = (node->GetInDataNodes().size() == kBufferPoolNodeInSize) &&
                                     (node->GetOutDataNodes().size() == kBufferPoolNodeOutSize);
    if (single_in_out_nodes) {
      continue;
    }
    const bool single_in_out_anchors = (node->GetAllInDataAnchors().size() == kBufferPoolNodeInSize) &&
                                       (node->GetAllOutDataAnchors().size() == kBufferPoolNodeOutSize);
    if (!single_in_out_anchors) {
      GELOGE(PARAM_INVALID, "[Check][InputOutput]Only support single input and single output, "
             "node:%s.", node->GetName().c_str());
      REPORT_INNER_ERROR("E19999", "Only support single input and single output, node:%s.", node->GetName().c_str());
      return PARAM_INVALID;
    }
    // A prefetching output is used in multiple places. Copy one so that the prefetching node
    // remains single input and single output.
    if (InsertMemCpyNodeAfter(graph, node) != SUCCESS) {
      GELOGE(INTERNAL_ERROR, "[Insert][MemCpy]Node:%s.", node->GetName().c_str());
      REPORT_INNER_ERROR("E19999", "Failed to insert mem copy node after %s.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }
    graph_modified = true;
    GELOGI("[Insert][Node]Insert mem copy node after %s.", node->GetName().c_str());
  }

  if (!graph_modified) {
    return SUCCESS;
  }
  Status sort_ret = graph->TopologicalSorting();
  if (sort_ret != SUCCESS) {
    GELOGE(sort_ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
    REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
    return sort_ret;
  }
  return SUCCESS;
}
/// @brief Record, per batch label, every buffer pool node together with the node that consumes it.
/// @param [in] graph graph whose nodes (GetAllNodes) are scanned
/// @return SUCCESS, or the status of SetBufferPoolSize when a pool size is inconsistent
Status BufferPoolMemoryPass::GetBufferPoolAndPeerCalcNodes(const ComputeGraphPtr &graph) {
  // Calc nodes already listed per batch label and pool id, so each one is recorded only once.
  std::unordered_map<std::string, std::unordered_map<int64_t, std::set<NodePtr>>> recorded_calc_nodes;
  for (const NodePtr &node : graph->GetAllNodes()) {
    auto all_in_nodes = node->GetInAllNodes();
    for (NodePtr &in_node : all_in_nodes) {
      int64_t buffer_pool_id = 0;
      int64_t buffer_pool_size = 0;
      // An input is a buffer pool node only when both attributes can be read from it.
      if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, buffer_pool_id) ||
          !AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, buffer_pool_size)) {
        continue;
      }
      std::string batch_label;
      (void) AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
      peer_buffer_node_item_[batch_label][node].emplace_back(BufferPoolNodeItem(in_node, 0, 0));
      buffer_node_to_calc_[batch_label][in_node] = node;
      auto &recorded_in_pool = recorded_calc_nodes[batch_label][buffer_pool_id];
      if (recorded_in_pool.count(node) == 0) {
        calc_nodes_[batch_label][buffer_pool_id].emplace_back(node);
        recorded_in_pool.insert(node);
      }
      GELOGI("[Get][BufferNode]Calc node:%s, pool node:%s.", node->GetName().c_str(), in_node->GetName().c_str());
      const Status set_ret = SetBufferPoolSize(batch_label, buffer_pool_id, buffer_pool_size);
      if (set_ret != SUCCESS) {
        GELOGE(set_ret, "[Set][BufferPoolSize]Node:%s", in_node->GetName().c_str());
        REPORT_INNER_ERROR("E19999", "Failed to set buffer pool size, something wrong with the info of node:%s",
                           in_node->GetName().c_str());
        return set_ret;
      }
    }
  }
  return SUCCESS;
}
/// @brief Remember the size of a buffer pool and reject a size that contradicts an earlier one.
/// @param [in] batch_label batch label the pool belongs to
/// @param [in] id buffer pool id
/// @param [in] size buffer pool size
/// @return SUCCESS, or PARAM_INVALID when the same id was registered with a different size
Status BufferPoolMemoryPass::SetBufferPoolSize(const std::string &batch_label, int64_t id, int64_t size) {
  auto &size_by_pool_id = buffer_pool_size_[batch_label];
  const auto known = size_by_pool_id.find(id);
  if (known == size_by_pool_id.end()) {
    size_by_pool_id[id] = size;
    return SUCCESS;
  }
  if (known->second != size) {
    GELOGE(PARAM_INVALID, "[Check][BufferPoolSize]Get different size with the same id, "
           "id:%ld, original size:%ld, this size:%ld.", id, known->second, size);
    REPORT_INNER_ERROR("E19999", "Get different size with the same id, "
                       "id:%ld, original size:%ld, this size:%ld.", id, known->second, size);
    return PARAM_INVALID;
  }
  // Same id registered again with the same size: nothing to update.
  return SUCCESS;
}
/// @brief Run the space allocation once for every batch label found in calc_nodes_.
/// @return SUCCESS, or the status of the first batch whose allocation failed
Status BufferPoolMemoryPass::AllocateAllBufferPoolSpace() {
  for (const auto &batch_entry : calc_nodes_) {
    const std::string batch_label = batch_entry.first;
    const Status alloc_ret = AllocateSpaceInBatch(batch_entry.second,
                                                  buffer_pool_size_[batch_label],
                                                  buffer_node_to_calc_[batch_label],
                                                  peer_buffer_node_item_[batch_label]);
    if (alloc_ret == SUCCESS) {
      GELOGI("[Alloc][InBatch]Alloc space in batch successfully, batch label:%s.", batch_label.c_str());
      continue;
    }
    GELOGE(alloc_ret, "[Alloc][InBatch]Batch_label:%s.", batch_label.c_str());
    REPORT_INNER_ERROR("E19999", "Failed to allocate space in batch, batch_label:%s.", batch_label.c_str());
    return alloc_ret;
  }
  return SUCCESS;
}
/// @brief Allocate space for every buffer pool of one batch.
/// @param [in] calc_nodes calc nodes per buffer pool id
/// @param [in] buffer_pool_size_map pool size per buffer pool id
/// @param [in] buffer_node_to_calc mapping handed over to each BufferPool
/// @param [in/out] buffer_pool_nodes_item buffer pool node items per calc node
/// @return SUCCESS, or the status of the first pool whose allocation failed
Status BufferPoolMemoryPass::AllocateSpaceInBatch(
    const std::map<int64_t, std::vector<NodePtr>> &calc_nodes,
    const std::unordered_map<int64_t, int64_t> &buffer_pool_size_map,
    const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc,
    std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item) {
  for (const auto &pool_entry : calc_nodes) {
    const int64_t pool_id = pool_entry.first;
    const int64_t pool_size = buffer_pool_size_map.at(pool_id);
    // Both event queues are emptied before each pool is processed.
    ClearQueue(mem_ctrl_event_);
    ClearQueue(stream_ctrl_event_);
    BufferPool buffer_pool(pool_id, pool_size, buffer_node_to_calc);
    const Status alloc_ret = AllocateSpaceInBufferPool(buffer_pool, pool_entry.second, buffer_pool_nodes_item);
    if (alloc_ret != SUCCESS) {
      GELOGE(alloc_ret, "[Alloc][InBufferPool]Pool id:%ld, pool size:%ld.", pool_id, pool_size);
      REPORT_INNER_ERROR("E19999", "Failed to allocate space in buffer pool, id:%ld, pool size:%ld.",
                         pool_id, pool_size);
      return alloc_ret;
    }
    GELOGI("[Alloc][InBufferPool]Alloc space in buffer pool successfully, pool id:%ld.", pool_id);
  }
  return SUCCESS;
}
/// @brief Allocate space inside one buffer pool for the buffer pool nodes feeding the given calc nodes.
/// @param [in] buffer_pool pool description (id, size, buffer node to calc node mapping)
/// @param [in] calc_nodes_in_pool calc nodes handled for this pool, processed in the given order
/// @param [in/out] buffer_pool_nodes_item buffer pool node items per calc node
/// @return SUCCESS, or the status of the first failing step
Status BufferPoolMemoryPass::AllocateSpaceInBufferPool(
    const BufferPool &buffer_pool,
    const std::vector<NodePtr> &calc_nodes_in_pool,
    std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item) {
  const int64_t pool_id = buffer_pool.pool_id;
  const int64_t pool_size = buffer_pool.pool_size;
  int64_t next_start = 0;
  NodePtr prev_pool_node = nullptr;
  // The queue starts with one item without a node that spans [0, pool size).
  std::queue<BufferPoolNodeItem> node_mem_range_in_pool;
  node_mem_range_in_pool.push(BufferPoolMemoryPass::BufferPoolNodeItem(nullptr, 0, pool_size));

  for (auto &calc_node : calc_nodes_in_pool) {
    auto &input_items = buffer_pool_nodes_item[calc_node];
    // Size required by all buffer pool inputs of this calc node, keyed by pool id.
    std::unordered_map<int64_t, int64_t> required_size;
    size_t visited_inputs = 0;
    for (auto &input_item : input_items) {
      auto pool_node = input_item.node;
      GE_CHECK_NOTNULL(pool_node);
      ++visited_inputs;

      int64_t node_size = 0;
      Status ret = GetMemorySize(pool_node, node_size);
      if (ret != SUCCESS) {
        GELOGE(ret, "[Get][MemSize]Node:%s, calc_node:%s.",
               pool_node->GetName().c_str(), calc_node->GetName().c_str());
        REPORT_INNER_ERROR("E19999", "Failed to get memory size, node:%s, calc_node:%s.",
                           pool_node->GetName().c_str(), calc_node->GetName().c_str());
        return ret;
      }

      ret = CheckBufferPoolSize(node_size, pool_id, pool_size, required_size);
      if (ret != SUCCESS) {
        GELOGE(ret, "[Check][BufferPoolSize]Capacity is not enough for all data, calc_node:%s.",
               calc_node->GetName().c_str());
        REPORT_INNER_ERROR("E19999", "Capacity is not enough for all data, calc_node:%s.",
                           calc_node->GetName().c_str());
        return ret;
      }

      // The last constructor argument marks the final buffer pool input of this calc node.
      const bool is_last_input = (visited_inputs == input_items.size());
      BufferPoolNodeItem new_item(pool_node, calc_node, prev_pool_node, node_size, 0, 0, is_last_input);
      ret = AllocateSpaceForBufferPoolNode(next_start, buffer_pool, new_item, node_mem_range_in_pool);
      if (ret != SUCCESS) {
        GELOGE(ret, "[Alloc][ForNode]Pool node:%s, calc_node:%s.",
               pool_node->GetName().c_str(), calc_node->GetName().c_str());
        REPORT_INNER_ERROR("E19999", "Failed to allocate space for buffer pool node:%s, calc_node:%s.",
                           pool_node->GetName().c_str(), calc_node->GetName().c_str());
        return ret;
      }
      prev_pool_node = pool_node;
    }
  }
  return SUCCESS;
}
/// Places one buffer pool node inside its pool and records the events / ordering fixes that go with it.
///
/// @param next_start              In/out: offset at which the node is placed. It may first be reset by
///                                GetOffsetAndDependency and is advanced by the node's total size on success.
/// @param buffer_pool             Pool description (id, size and buffer-node -> calc-node map).
/// @param buffer_pool_node_item   In/out: the item being placed; its offset_start / offset_end are filled in here.
/// @param node_mem_range_in_pool  In/out: queue of items currently tracked in the pool; a copy of the placed item
///                                is appended.
/// @return SUCCESS, or the failing status of FixTheTimingOfDependentNodes / TryToFixNodeOrder.
Status BufferPoolMemoryPass::AllocateSpaceForBufferPoolNode(int64_t &next_start,
                                                            const BufferPool buffer_pool,
                                                            BufferPoolNodeItem &buffer_pool_node_item,
                                                            std::queue<BufferPoolNodeItem> &node_mem_range_in_pool) {
  NodePtr pool_node = buffer_pool_node_item.node;
  NodePtr consumer_node = buffer_pool_node_item.out_calc_node;
  // Snapshot of the counter; replaced below only when this node actually sends an event.
  uint32_t send_event_id = logic_event_num_;

  /// In the scenario where there are multiple PREFETCH operators in the inputs of the calculation operator,
  /// the addition of events is optimized to only add events after the last PREFETCH operator.
  ///
  ///      w1         w2         w3         w4         w5
  ///       |          |          |          |          |
  ///   prefetch1  prefetch2  prefetch3  prefetch4  prefetch5   xxx
  ///         \      /            \      /            \       /
  ///          node1               node2                node3
  ///            |                   |                    |
  ///            ---------------- other nodes -------------
  ///
  /// The send event id must be generated before FixTheTimingOfDependentNodes runs, because that call may add a
  /// new id to stream_ctrl_event_, and this id cannot be reused until the next PREFETCH operator in the sequence.
  if (buffer_pool_node_item.is_last_input) {
    send_event_id = GenerateEventId(pool_node->GetName(), stream_ctrl_event_);
    const std::string send_record = "SendTo;" + consumer_node->GetName() + ";" + std::to_string(send_event_id);
    node_event_multiplexing_[pool_node].push_back(send_record);
    mem_ctrl_event_.push(std::make_pair(consumer_node->GetName(), send_event_id));
  }

  // Decide the offset (next_start may be reset) and find the calc node, if any, this placement must wait for.
  NodePtr blocking_calc_node = GetOffsetAndDependency(next_start, buffer_pool_node_item.total_size,
                                                      buffer_pool.pool_size,
                                                      buffer_pool.buffer_node_to_calc,
                                                      node_mem_range_in_pool);
  if (blocking_calc_node != nullptr) {
    Status timing_status = FixTheTimingOfDependentNodes(blocking_calc_node, pool_node);
    if (timing_status != SUCCESS) {
      GELOGE(timing_status, "[Fix][Timing]Pool_id:%ld, pool node:%s, dependent node:%s.",
             buffer_pool.pool_id, pool_node->GetName().c_str(), blocking_calc_node->GetName().c_str());
      REPORT_INNER_ERROR("E19999", "Failed to fix timing, pool_id:%ld, pool node:%s, dependent node:%s.",
                         buffer_pool.pool_id, pool_node->GetName().c_str(),
                         blocking_calc_node->GetName().c_str());
      return timing_status;
    }
  }

  // Record {size, offset} for the node, then advance the cursor past it.
  buffer_pool_node_item.offset_start = next_start;
  auto &size_and_offset = buffer_node_logical_offset_[pool_node];
  size_and_offset.push_back(buffer_pool_node_item.total_size);
  size_and_offset.push_back(next_start);
  FMK_INT64_ADDCHECK(next_start, buffer_pool_node_item.total_size);
  next_start += buffer_pool_node_item.total_size;
  buffer_pool_node_item.offset_end = next_start;
  node_mem_range_in_pool.push(buffer_pool_node_item);

  // Order this node relative to the previous buffer pool node of the same calc node, when there is one.
  if (buffer_pool_node_item.pre_buffer_pool_node != nullptr) {
    bool order_unchanged = true;
    auto order_status = TryToFixNodeOrder(buffer_pool_node_item.pre_buffer_pool_node, pool_node, order_unchanged);
    if (order_status != SUCCESS) {
      GELOGE(order_status, "[Fix][BufferPoolNodeOrder]Pre node:%s, curr node:%s.",
             buffer_pool_node_item.pre_buffer_pool_node->GetName().c_str(), pool_node->GetName().c_str());
      return order_status;
    }
  }

  GELOGI("[Alloc][ForNode]Buffer pool node %s send to %s, offset start:%ld, send event id:%u.",
         pool_node->GetName().c_str(), consumer_node->GetName().c_str(),
         buffer_pool_node_item.offset_start, send_event_id);
  return SUCCESS;
}
/// When generating the event ID, determine whether the name of the queue head node is the same as the name of | |||
/// the operator, in order to handle such scenarios: | |||
/// w1 w2 w3 w4 w5 | |||
/// | | | | | | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | |||
/// | | | | | | |||
/// node1 node2 node3 node4 node5 | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |____w1_____|__| | |||
/// | |||
/// |____w2_____|__| | |||
/// | |||
/// |____w3_____|__| | |||
/// | |||
/// |______w4______| | |||
/// | |||
/// |______w5______| | |||
/// | |||
/// In this scenario, prefetch2 depends on node1. If the dependency is handled by adding an event of node1 to prefetch2, | |||
/// the id sent by prefetch2 will be the same as the id it receives. Although Runtime supports this through WaitReset, | |||
/// we consider this a dangerous operation and avoid it. | |||
/// Returns a logical event id for `node_name`, reusing a queued id when possible.
///
/// If `event_queue` is non-empty and its head was queued under a name different from `node_name`, the head's id
/// is popped and returned. Otherwise (empty queue, or head queued under the same name) a fresh id is taken from
/// logic_event_num_, which is then incremented.
///
/// @param node_name    Name compared against the name stored with the queue head.
/// @param event_queue  In/out: queue of (name, event id) pairs available for reuse.
/// @return The reused or newly allocated event id.
uint32_t BufferPoolMemoryPass::GenerateEventId(const std::string &node_name,
                                               std::queue<std::pair<std::string, uint32_t>> &event_queue) {
  // Inspect the head in place: copying the pair would copy its std::string for no benefit.
  if (!event_queue.empty() && event_queue.front().first != node_name) {
    const uint32_t reused_event = event_queue.front().second;
    event_queue.pop();
    return reused_event;
  }
  // Nothing reusable: hand out the current counter value and advance the counter.
  return logic_event_num_++;
}
/// Chooses where the next tensor goes in the pool and finds the calc node whose completion it must wait for.
///
/// @param next_start           In/out: candidate start offset; reset to 0 when the tensor does not fit before
///                             `buffer_pool_size`.
/// @param total_mem_size       Size of the tensor to place.
/// @param buffer_pool_size     Total size of the pool.
/// @param buffer_node_to_calc  Map from a buffer pool node to its calc node.
/// @param nodes_in_buffer      In/out: queue of items tracked in the pool; entries that no longer constrain the
///                             placement are popped.
/// @return The calc node mapped to the first queued item whose offset_end is not exceeded by the new range, or
///         nullptr when the queue runs empty, that item has no node, or the node has no entry in the map.
NodePtr BufferPoolMemoryPass::GetOffsetAndDependency(int64_t &next_start,
                                                     int64_t total_mem_size,
                                                     int64_t buffer_pool_size,
                                                     const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc,
                                                     std::queue<BufferPoolMemoryPass::BufferPoolNodeItem> &nodes_in_buffer) {
  // The buffer pool can no longer fit this tensor behind next_start, so start again from offset 0.
  if (next_start + total_mem_size > buffer_pool_size) {
    next_start = 0;
    if (!nodes_in_buffer.empty()) {
      // The most recently placed item takes up the rest of the space at the end of the pool.
      nodes_in_buffer.back().offset_end = buffer_pool_size;
      // Drop the current queue head unconditionally before searching for the start of a round.
      nodes_in_buffer.pop();
    }
    // Discard entries until the head is one that starts at offset 0 (the beginning of a round).
    while (!nodes_in_buffer.empty() && nodes_in_buffer.front().offset_start != 0) {
      nodes_in_buffer.pop();
    }
  }

  while (!nodes_in_buffer.empty()) {
    // Bind by reference: copying the item would copy three shared pointers per iteration.
    // The reference is never used after the pop() below.
    const BufferPoolNodeItem &occupant = nodes_in_buffer.front();
    if (next_start + total_mem_size <= occupant.offset_end) {
      // The new range ends inside this item's range; its calc node (if known) is the dependency.
      if (occupant.node == nullptr) {
        return nullptr;
      }
      const auto calc_iter = buffer_node_to_calc.find(occupant.node);
      if (calc_iter != buffer_node_to_calc.end()) {
        return calc_iter->second;
      }
      return nullptr;
    }
    // The new range extends past this item, so it no longer needs to be tracked.
    nodes_in_buffer.pop();
  }
  return nullptr;
}
/// Orders `curr_pool_node` after `dependent_calc_node` and, when that changed something, records a receive event.
///
/// TryToFixNodeOrder is asked to fix the order first. If it reports `not_change`, nothing else is recorded.
/// Otherwise an event id is obtained via GenerateEventId (reusing ids queued in mem_ctrl_event_), a
/// "RecvFrom;<calc node name>;<event id>" entry is appended for the pool node, and the id is queued in
/// stream_ctrl_event_ under the pool node's name.
///
/// @param dependent_calc_node  Calc node the pool node has to wait for. Callers guarantee it is not null.
/// @param curr_pool_node       Buffer pool node being placed. Callers guarantee it is not null.
/// @return SUCCESS, or the status returned by TryToFixNodeOrder on failure.
Status BufferPoolMemoryPass::FixTheTimingOfDependentNodes(NodePtr &dependent_calc_node, NodePtr &curr_pool_node) {
  bool order_unchanged = false;
  Status order_status = TryToFixNodeOrder(dependent_calc_node, curr_pool_node, order_unchanged);
  if (order_status != SUCCESS) {
    GELOGE(order_status, "[Fix][NodeOrder]Src:%s, dst:%s.",
           dependent_calc_node->GetName().c_str(), curr_pool_node->GetName().c_str());
    return order_status;
  }

  if (!order_unchanged) {
    const uint32_t recv_event_id = GenerateEventId(dependent_calc_node->GetName(), mem_ctrl_event_);
    const std::string recv_record =
        "RecvFrom;" + dependent_calc_node->GetName() + ";" + std::to_string(recv_event_id);
    node_event_multiplexing_[curr_pool_node].push_back(recv_record);
    stream_ctrl_event_.push(std::make_pair(curr_pool_node->GetName(), recv_event_id));
    GELOGI("[Fix][Timing]Add ctrl edge for buffer pool memory from %s to %s, buffer pool node recv event:%u.",
           dependent_calc_node->GetName().c_str(), curr_pool_node->GetName().c_str(), recv_event_id);
  }
  return SUCCESS;
}
/// Writes the collected event and memory results onto the op descriptors as attributes.
///
/// For every node that has an entry in node_event_multiplexing_:
///  - its event strings are stored under ATTR_NAME_EVENT_MULTIPLEXING;
///  - its entry from buffer_node_logical_offset_ is stored under ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET.
///
/// @return SUCCESS when every attribute was set; INTERNAL_ERROR when setting an attribute fails or a node has no
///         recorded size/offset. Null nodes or op descriptors are rejected by GE_CHECK_NOTNULL.
Status BufferPoolMemoryPass::SetResultOfMemoryAndEvent() {
  for (auto &event_entry : node_event_multiplexing_) {
    const NodePtr &node = event_entry.first;
    GE_CHECK_NOTNULL(node);
    auto op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);

    if (!AttrUtils::SetListStr(op_desc, ATTR_NAME_EVENT_MULTIPLEXING, event_entry.second)) {
      GELOGE(INTERNAL_ERROR, "[Set][Attr]Node:%s.", node->GetName().c_str());
      REPORT_CALL_ERROR("E19999", "Failed to set event reuse info, node:%s.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }

    // A node with events but no recorded size/offset is treated as an internal inconsistency.
    const auto size_offset_iter = buffer_node_logical_offset_.find(node);
    if (size_offset_iter == buffer_node_logical_offset_.end()) {
      GELOGE(INTERNAL_ERROR, "[Get][LogicalOffset]Node:%s.", node->GetName().c_str());
      REPORT_INNER_ERROR("E19999", "Failed to get logical offset and size, node:%s.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }

    if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET, size_offset_iter->second)) {
      GELOGE(INTERNAL_ERROR, "[Set][Attr]Node:%s.", node->GetName().c_str());
      REPORT_CALL_ERROR("E19999", "Failed to set node memory offset and size, node:%s.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }
  return SUCCESS;
}
} // namespace ge |
@@ -0,0 +1,136 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_ | |||
#define GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_ | |||
#include <queue> | |||
#include "graph/graph.h" | |||
#include "inc/graph_pass.h" | |||
namespace ge { | |||
class BufferPoolMemoryPass : public GraphPass { | |||
public: | |||
explicit BufferPoolMemoryPass() : logic_event_num_(0) {} | |||
~BufferPoolMemoryPass() override = default; | |||
struct BufferPool { | |||
int64_t pool_id = 0; | |||
int64_t pool_size = 0; | |||
std::unordered_map<NodePtr, NodePtr> buffer_node_to_calc; | |||
BufferPool(int64_t id, int64_t size, const std::unordered_map<NodePtr, NodePtr> &node_map) | |||
: pool_id(id), pool_size(size), buffer_node_to_calc(node_map) {} | |||
}; | |||
struct BufferPoolNodeItem { | |||
NodePtr node = nullptr; | |||
NodePtr out_calc_node = nullptr; | |||
NodePtr pre_buffer_pool_node = nullptr; | |||
int64_t total_size = 0; | |||
int64_t offset_start = 0; | |||
int64_t offset_end = 0; | |||
bool is_last_input = true; | |||
BufferPoolNodeItem(const NodePtr &buffer_n, const NodePtr &calc_n, const NodePtr &pre_buffer_n, | |||
int64_t size, int64_t start, int64_t end, bool last) | |||
: node(std::move(buffer_n)), | |||
out_calc_node(std::move(calc_n)), | |||
pre_buffer_pool_node(std::move(pre_buffer_n)), | |||
total_size(size), | |||
offset_start(start), | |||
offset_end(end), | |||
is_last_input(last) {} | |||
BufferPoolNodeItem(const NodePtr &buffer_n, int64_t start, int64_t end) | |||
: node(std::move(buffer_n)), | |||
out_calc_node(nullptr), | |||
pre_buffer_pool_node(nullptr), | |||
total_size(0), | |||
offset_start(start), | |||
offset_end(end), | |||
is_last_input(true) {} | |||
}; | |||
Status Run(ComputeGraphPtr graph) override; | |||
private: | |||
static void ClearQueue(std::queue<std::pair<std::string, uint32_t>> &q); | |||
static Status IsBufferPoolMemEnable(const ComputeGraphPtr &graph); | |||
static Status CheckBufferPoolSize(int64_t total_size, int64_t pool_id, int64_t buffer_pool_size, | |||
std::unordered_map<int64_t, int64_t> &calc_total_size); | |||
static Status TryToFixNodeOrder(NodePtr &pre_node, NodePtr &curr_node, bool ¬_change); | |||
Status InsertMemCpyNodeAfter(ComputeGraphPtr &graph, NodePtr &node); | |||
Status CopyOutForMultiUsedOutput(ComputeGraphPtr &graph); | |||
Status GetBufferPoolAndPeerCalcNodes(const ComputeGraphPtr &graph); | |||
Status SetBufferPoolSize(const std::string &batch_label, int64_t id, int64_t size); | |||
Status AllocateAllBufferPoolSpace(); | |||
Status AllocateSpaceInBatch(const std::map<int64_t, std::vector<NodePtr>> &calc_nodes, | |||
const std::unordered_map<int64_t, int64_t> &buffer_pool_size_map, | |||
const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc, | |||
std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item); | |||
Status AllocateSpaceInBufferPool(const BufferPool &buffer_pool, | |||
const std::vector<NodePtr> &calc_nodes_in_pool, | |||
std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item); | |||
Status AllocateSpaceForBufferPoolNode(int64_t &next_start, | |||
const BufferPool buffer_pool, | |||
BufferPoolNodeItem &buffer_pool_node_item, | |||
std::queue<BufferPoolNodeItem> &node_mem_range_in_pool); | |||
NodePtr GetOffsetAndDependency(int64_t &next_start, | |||
int64_t total_mem_size, | |||
int64_t buffer_pool_size, | |||
const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc, | |||
std::queue<BufferPoolNodeItem> &nodes_in_buffer); | |||
Status FixTheTimingOfDependentNodes(NodePtr &dependent_calc_node, NodePtr &curr_pool_node); | |||
uint32_t GenerateEventId(const std::string &node_name, std::queue<std::pair<std::string, uint32_t>> &event_queue); | |||
Status SetResultOfMemoryAndEvent(); | |||
// Use map to ensure that each visit is in the order of batch label and pool id | |||
std::map<std::string, std::map<int64_t, std::vector<NodePtr>>> calc_nodes_; | |||
std::unordered_map<std::string, std::unordered_map<NodePtr, NodePtr>> buffer_node_to_calc_; | |||
std::unordered_map<std::string, std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>>> peer_buffer_node_item_; | |||
std::unordered_map<std::string, std::unordered_map<int64_t, int64_t>> buffer_pool_size_; | |||
uint32_t logic_event_num_; | |||
std::queue<std::pair<std::string, uint32_t>> mem_ctrl_event_; | |||
std::queue<std::pair<std::string, uint32_t>> stream_ctrl_event_; | |||
std::unordered_map<NodePtr, std::vector<std::string>> node_event_multiplexing_; | |||
std::unordered_map<NodePtr, std::vector<int64_t>> buffer_node_logical_offset_; | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_ |
@@ -43,6 +43,11 @@ rtError_t rtEventCreate(rtEvent_t *event) { | |||
*event = new int[EVENT_LENTH]; | |||
return RT_ERROR_NONE; | |||
} | |||
// Stub: the flag is ignored and event creation is delegated to rtEventCreate.
rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag) {
  (void)flag;
  const rtError_t create_status = rtEventCreate(event);
  return create_status;
}
// Stub: does nothing with its arguments and always reports RT_ERROR_NONE.
rtError_t rtEventRecord(rtEvent_t event, rtStream_t stream) {
  (void)event;
  (void)stream;
  return RT_ERROR_NONE;
}
// Stub: does nothing with its argument and always reports RT_ERROR_NONE.
rtError_t rtEventSynchronize(rtEvent_t event) {
  (void)event;
  return RT_ERROR_NONE;
}
@@ -276,6 +276,7 @@ set(COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc" | |||
"${GE_CODE_DIR}/ge/model/ge_model.cc" | |||
"${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" | |||
"${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" | |||
@@ -323,6 +324,7 @@ set(COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/build/memory/block_mem_assigner.cc" | |||
"${GE_CODE_DIR}/ge/graph/build/memory/binary_block_mem_assigner.cc" | |||
"${GE_CODE_DIR}/ge/graph/build/memory/max_block_mem_assigner.cc" | |||
"${GE_CODE_DIR}/ge/graph/build/memory/buffer_pool_mem_assigner.cc" | |||
"${GE_CODE_DIR}/ge/graph/manager/graph_mem_allocator.cc" | |||
"${GE_CODE_DIR}/ge/graph/manager/graph_var_manager.cc" | |||
"${GE_CODE_DIR}/ge/analyzer/analyzer.cc" | |||
@@ -627,6 +629,7 @@ set(SINGLE_OP_SRC_FILES | |||
# test files | |||
set(COMMON_TEST_FILES | |||
"graph/passes/graph_builder_utils.cc" | |||
"graph/utils/buffer_pool_graph_builder.cc" | |||
"test.cc" | |||
) | |||
@@ -703,6 +706,7 @@ set(PASS_TEST_FILES | |||
"graph/passes/link_gen_mask_nodes_pass_unittest.cc" | |||
"graph/passes/transpose_transdata_pass_unittest.cc" | |||
"graph/passes/parallel_group_pass_unittest.cc" | |||
"graph/passes/buffer_pool_memory_pass_unittest.cc" | |||
) | |||
set(KERNEL_TEST_FILES | |||
@@ -771,6 +775,7 @@ set(MULTI_PARTS_TEST_FILES | |||
"graph/build/model_builder_unittest.cc" | |||
"graph/build/mem_assigner_unittest.cc" | |||
"graph/build/task_generator_unittest.cc" | |||
"graph/build/buffer_pool_mem_assigner_unittest.cc" | |||
"graph/preprocess/graph_preprocess_unittest.cc" | |||
"graph/manager/hcom_util_unittest.cc" | |||
"graph/manager/graph_caching_allocator_unittest.cc" | |||
@@ -0,0 +1,607 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include "common/ge_inner_error_codes.h" | |||
#include "common/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "../utils/buffer_pool_graph_builder.h" | |||
#include "graph/passes/buffer_pool_memory_pass.h" | |||
#define protected public | |||
#define private public | |||
#include "graph/build/memory/buffer_pool_mem_assigner.h" | |||
#include "graph/build/memory/graph_mem_assigner.h" | |||
#include "graph/build/stream_allocator.h" | |||
#undef protected | |||
#undef private | |||
namespace ge { | |||
namespace { | |||
const int64_t kMemoryTypeHBM = static_cast<int64_t>(RT_MEMORY_HBM); | |||
const int64_t kMemoryTypeP2P = static_cast<int64_t>(RT_MEMORY_P2P_HBM); | |||
const int64_t kMemoryTypeDDR = static_cast<int64_t>(RT_MEMORY_DDR); | |||
const size_t kOffsetHBM = 10240; | |||
const size_t kOffsetP2P = 20480; | |||
const size_t kOffsetDDR = 30720; | |||
// Alignment unit, in bytes, used by the expected-offset computations in these tests.
const int64_t kMemAlignSize = 512;

/// Rounds `mem_size` up to the next multiple of `align_size`.
/// A non-positive `align_size` cannot be aligned to (and would divide by zero), so `mem_size` is returned as is.
int64_t AlignMemSize(int64_t mem_size, int64_t align_size = kMemAlignSize) {
  if (align_size <= 0) {
    return mem_size;
  }
  return (mem_size + align_size - 1) / align_size * align_size;
}

/// Size reserved for an output: `mem_size` rounded up to kMemAlignSize, plus one extra alignment unit
/// in front of and one behind the data.
int64_t AlignOutputMemSize(int64_t mem_size) {
  // hccl need alignment
  return kMemAlignSize + AlignMemSize(mem_size) + kMemAlignSize;
}
} // namespace | |||
class UtestBufferPoolMemAssignerTest : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
// Runs the buffer pool pass and the assigner on the "normal" graph and checks the final memory offset
// as well as the output offset of every prefetch node.
TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_normal_assign_success) {
  ut::BufferPoolGraphBuilder builder("NormalGraph");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
  ASSERT_NE(graph, nullptr);
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  ASSERT_EQ(ret, SUCCESS);

  std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
                                                  {kMemoryTypeP2P, kOffsetP2P}};
  int64_t offset_base = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
  // Expected output offsets, in the same order as prefetch_names below.
  std::vector<int64_t> expect_offset = {(offset_base + 0),
                                        (offset_base + AlignOutputMemSize(500)),
                                        (offset_base + (AlignOutputMemSize(500) * 2)),
                                        (offset_base + 0),
                                        (offset_base + AlignOutputMemSize(1024))};

  BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
  ret = buffer_pool_mem_assigner.Assign();
  ASSERT_EQ(ret, SUCCESS);
  EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base +
            AlignMemSize(5600, kMemAlignSize) + kMemAlignSize);

  const std::vector<const char *> prefetch_names = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"};
  ASSERT_EQ(prefetch_names.size(), expect_offset.size());
  for (size_t i = 0; i < prefetch_names.size(); ++i) {
    SCOPED_TRACE(prefetch_names[i]);
    auto prefetch = graph->FindNode(prefetch_names[i]);
    // ASSERT (not EXPECT): the pointers are dereferenced right below.
    ASSERT_NE(prefetch, nullptr);
    ASSERT_NE(prefetch->GetOpDesc(), nullptr);
    std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
    ASSERT_EQ(output_offset.size(), 1U);
    EXPECT_EQ(output_offset.at(0), expect_offset.at(i));
  }
}
// Same check as the normal graph, but with the prefetch nodes spread over two buffer pools;
// the second pool's base offset follows the first pool's aligned size plus one alignment unit.
TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_normal_graph_with_multi_buffer_pool_assign_success) {
  ut::BufferPoolGraphBuilder builder("NormalGraphWithMultiBufferPool");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraphWithMultiBufferPool();
  ASSERT_NE(graph, nullptr);
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  ASSERT_EQ(ret, SUCCESS);

  std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
                                                  {kMemoryTypeP2P, kOffsetP2P}};
  int64_t offset_base_0 = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
  int64_t offset_base_1 = static_cast<int64_t>(kOffsetHBM + kMemAlignSize) +
                          AlignMemSize(5000, kMemAlignSize) + kMemAlignSize;
  // Expected output offsets, in the same order as prefetch_names below.
  std::vector<int64_t> expect_offset = {(offset_base_0 + 0),
                                        (offset_base_1 + 0),
                                        (offset_base_0 + AlignOutputMemSize(500)),
                                        (offset_base_0 + 0),
                                        (offset_base_1 + AlignOutputMemSize(500))};

  BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
  ret = buffer_pool_mem_assigner.Assign();
  ASSERT_EQ(ret, SUCCESS);
  EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base_1 +
            AlignMemSize(5000, kMemAlignSize) + kMemAlignSize);

  const std::vector<const char *> prefetch_names = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"};
  ASSERT_EQ(prefetch_names.size(), expect_offset.size());
  for (size_t i = 0; i < prefetch_names.size(); ++i) {
    SCOPED_TRACE(prefetch_names[i]);
    auto prefetch = graph->FindNode(prefetch_names[i]);
    // ASSERT (not EXPECT): the pointers are dereferenced right below.
    ASSERT_NE(prefetch, nullptr);
    ASSERT_NE(prefetch->GetOpDesc(), nullptr);
    std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
    ASSERT_EQ(output_offset.size(), 1U);
    EXPECT_EQ(output_offset.at(0), expect_offset.at(i));
  }
}
// On the serial graph every prefetch node is expected to be placed at the pool's base offset.
TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_serial_graph_assign_success) {
  ut::BufferPoolGraphBuilder builder("SerialGraph");
  ge::ComputeGraphPtr graph = builder.BuildSerialGraph();
  ASSERT_NE(graph, nullptr);
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  ASSERT_EQ(ret, SUCCESS);

  std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
                                                  {kMemoryTypeP2P, kOffsetP2P}};
  int64_t offset_base = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
  // Expected output offsets, in the same order as prefetch_names below.
  std::vector<int64_t> expect_offset = {offset_base, offset_base, offset_base, offset_base, offset_base};

  BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
  ret = buffer_pool_mem_assigner.Assign();
  ASSERT_EQ(ret, SUCCESS);
  EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base +
            AlignMemSize(2048, kMemAlignSize) + kMemAlignSize);

  const std::vector<const char *> prefetch_names = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"};
  ASSERT_EQ(prefetch_names.size(), expect_offset.size());
  for (size_t i = 0; i < prefetch_names.size(); ++i) {
    SCOPED_TRACE(prefetch_names[i]);
    auto prefetch = graph->FindNode(prefetch_names[i]);
    // ASSERT (not EXPECT): the pointers are dereferenced right below.
    ASSERT_NE(prefetch, nullptr);
    ASSERT_NE(prefetch->GetOpDesc(), nullptr);
    std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
    ASSERT_EQ(output_offset.size(), 1U);
    EXPECT_EQ(output_offset.at(0), expect_offset.at(i));
  }
}
// Same expectations as the normal graph, but the prefetch nodes are looked up among all nodes
// returned by GetAllNodes() instead of through FindNode().
TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_subgraph_with_inner_dependency_assign_success) {
  ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency");
  ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency();
  ASSERT_NE(graph, nullptr);
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  ASSERT_EQ(ret, SUCCESS);

  std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM},
                                                  {kMemoryTypeP2P, kOffsetP2P}};
  int64_t offset_base = static_cast<int64_t>(kOffsetHBM + kMemAlignSize);
  // Expected output offsets, in the same order as prefetch_names below.
  std::vector<int64_t> expect_offset = {(offset_base + 0),
                                        (offset_base + AlignOutputMemSize(500)),
                                        (offset_base + (AlignOutputMemSize(500) * 2)),
                                        (offset_base + 0),
                                        (offset_base + AlignOutputMemSize(1024))};

  BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset);
  ret = buffer_pool_mem_assigner.Assign();
  ASSERT_EQ(ret, SUCCESS);
  EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base +
            AlignMemSize(5600, kMemAlignSize) + kMemAlignSize);

  // Index every node by name.
  std::map<std::string, NodePtr> all_nodes;
  for (auto node : graph->GetAllNodes()) {
    // ASSERT (not EXPECT): the node is dereferenced on the next line.
    ASSERT_NE(node, nullptr);
    all_nodes[node->GetName()] = node;
  }

  const std::vector<const char *> prefetch_names = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"};
  ASSERT_EQ(prefetch_names.size(), expect_offset.size());
  for (size_t i = 0; i < prefetch_names.size(); ++i) {
    SCOPED_TRACE(prefetch_names[i]);
    // find() instead of at(): a missing node becomes a test failure rather than an exception.
    const auto node_iter = all_nodes.find(prefetch_names[i]);
    ASSERT_TRUE(node_iter != all_nodes.end());
    auto prefetch = node_iter->second;
    ASSERT_NE(prefetch, nullptr);
    ASSERT_NE(prefetch->GetOpDesc(), nullptr);
    std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset();
    ASSERT_EQ(output_offset.size(), 1U);
    EXPECT_EQ(output_offset.at(0), expect_offset.at(i));
  }
}
TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_graph_with_multi_batch_assign_success) { | |||
ut::BufferPoolGraphBuilder builder("GraphWithMultiBatch"); | |||
ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiBatch(); | |||
BufferPoolMemoryPass buffer_pool_mem_pass; | |||
Status ret = buffer_pool_mem_pass.Run(graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
std::map<int64_t, size_t> mem_type_to_offset = {{kMemoryTypeHBM, kOffsetHBM}, | |||
{kMemoryTypeP2P, kOffsetP2P}}; | |||
int64_t offset_base = static_cast<int64_t>(kOffsetHBM + kMemAlignSize); | |||
std::vector<int64_t> expect_offset = {(offset_base + 0), | |||
(offset_base + AlignOutputMemSize(500)), | |||
(offset_base + (AlignOutputMemSize(500) * 2)), | |||
(offset_base + 0), | |||
(offset_base + AlignOutputMemSize(1024))}; | |||
BufferPoolMemAssigner buffer_pool_mem_assigner(graph, mem_type_to_offset); | |||
ret = buffer_pool_mem_assigner.Assign(); | |||
EXPECT_EQ(ret, SUCCESS); | |||
EXPECT_EQ(buffer_pool_mem_assigner.GetMemOffset(), offset_base + | |||
AlignMemSize(5600, kMemAlignSize) + kMemAlignSize); | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_128/prefetch1"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(0)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_128/prefetch2"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(1)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_128/prefetch3"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(2)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_128/prefetch4"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(3)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_128/prefetch5"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(4)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_256/prefetch1"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(0)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_256/prefetch2"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(1)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_256/prefetch3"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(2)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_256/prefetch4"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(3)); | |||
} | |||
{ | |||
auto prefetch = graph->FindNode("batch_label_256/prefetch5"); | |||
EXPECT_NE(prefetch, nullptr); | |||
EXPECT_NE(prefetch->GetOpDesc(), nullptr); | |||
std::vector<int64_t> output_offset = prefetch->GetOpDesc()->GetOutputOffset(); | |||
EXPECT_EQ(output_offset.size(), 1); | |||
EXPECT_EQ(output_offset.at(0), expect_offset.at(4)); | |||
} | |||
} | |||
// Runs BufferPoolMemoryPass on the "NormalGraph" fixture graph, then checks that
// GraphMemoryAssigner::AssignBufferPoolMemory() succeeds when HBM and P2P memory offsets are provided.
TEST_F(UtestBufferPoolMemAssignerTest, test_AssignBufferPoolMemory_success) {
  ut::BufferPoolGraphBuilder graph_builder("NormalGraph");
  ge::ComputeGraphPtr compute_graph = graph_builder.BuildNormalGraph();

  // The pass must complete before memory can be assigned.
  BufferPoolMemoryPass pass;
  EXPECT_EQ(pass.Run(compute_graph), SUCCESS);

  // Seed the assigner with the starting offsets of both memory types.
  std::map<int64_t, MemoryOffset> offsets = {{kMemoryTypeHBM, MemoryOffset(RT_MEMORY_HBM, kOffsetHBM)},
                                             {kMemoryTypeP2P, MemoryOffset(RT_MEMORY_P2P_HBM, kOffsetP2P)}};
  GraphMemoryAssigner assigner(compute_graph);
  assigner.memory_offset_ = offsets;

  EXPECT_EQ(assigner.AssignBufferPoolMemory(), SUCCESS);
}
// Checks that GraphMemoryAssigner::AssignBufferPoolMemory() returns FAILED on the "NormalGraph" fixture graph
// after the output memory type list of prefetch nodes has been overridden:
//   1. only "prefetch3" is set to RT_MEMORY_P2P_HBM;
//   2. all five prefetch nodes are set to RT_MEMORY_L1.
// The two cases share one graph, so the attribute set in case 1 is overwritten by case 2.
TEST_F(UtestBufferPoolMemAssignerTest, test_AssignBufferPoolMemory_fail) {
  ut::BufferPoolGraphBuilder builder("NormalGraph");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
  std::map<int64_t, MemoryOffset> memory_offset = {{kMemoryTypeHBM, MemoryOffset(RT_MEMORY_HBM, kOffsetHBM)},
                                                   {kMemoryTypeP2P, MemoryOffset(RT_MEMORY_P2P_HBM, kOffsetP2P)}};
  {
    auto prefetch = graph->FindNode("prefetch3");
    // Fatal assertions: the node and its op desc are dereferenced right below.
    ASSERT_NE(prefetch, nullptr);
    ASSERT_NE(prefetch->GetOpDesc(), nullptr);
    std::vector<int64_t> type_list = {static_cast<int64_t>(RT_MEMORY_P2P_HBM)};
    bool set_attr = ge::AttrUtils::SetListInt(prefetch->GetOpDesc(), ATTR_NAME_OUTPUT_MEM_TYPE_LIST, type_list);
    EXPECT_EQ(set_attr, true);
    GraphMemoryAssigner graph_memory_assigner(graph);
    graph_memory_assigner.memory_offset_ = memory_offset;
    Status ret = graph_memory_assigner.AssignBufferPoolMemory();
    EXPECT_EQ(ret, FAILED);
  }
  {
    std::vector<std::string> node_list = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"};
    std::vector<int64_t> type_list = {static_cast<int64_t>(RT_MEMORY_L1)};
    for (const auto &node_name : node_list) {
      auto prefetch = graph->FindNode(node_name);
      // Fatal assertions: the node and its op desc are dereferenced right below.
      ASSERT_NE(prefetch, nullptr);
      ASSERT_NE(prefetch->GetOpDesc(), nullptr);
      bool set_attr = ge::AttrUtils::SetListInt(prefetch->GetOpDesc(), ATTR_NAME_OUTPUT_MEM_TYPE_LIST, type_list);
      EXPECT_EQ(set_attr, true);
    }
    GraphMemoryAssigner graph_memory_assigner(graph);
    graph_memory_assigner.memory_offset_ = memory_offset;
    Status ret = graph_memory_assigner.AssignBufferPoolMemory();
    EXPECT_EQ(ret, FAILED);
  }
}
// Runs BufferPoolMemoryPass on the "NormalGraph" fixture graph, registers send/recv events on a StreamAllocator
// and checks that RefreshEventsWithReuse() succeeds and leaves the expected per-node event counts and event_num_.
TEST_F(UtestBufferPoolMemAssignerTest, test_RefreshEventsWithReuse_success) {
  ut::BufferPoolGraphBuilder builder("NormalGraph");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Index every node by name so that events can be attached by node name.
  std::map<std::string, NodePtr> all_nodes;
  for (const auto &node : graph->GetAllNodes()) {
    // Fatal assertion: the node is dereferenced on the next line.
    ASSERT_NE(node, nullptr);
    all_nodes[node->GetName()] = node;
  }

  Graph2SubGraphInfoList sub_graphs;
  StreamAllocator stream_allocator(graph, sub_graphs);
  stream_allocator.event_num_ = 65520;

  // One event: `sender` sends `event_id`, `receiver` receives it.
  struct EventLink {
    std::string sender;
    std::string receiver;
    uint32_t event_id;
  };
  const std::vector<EventLink> event_links = {
      // stream ctrl event
      {"prefetch1", "add1", 30},
      {"prefetch2", "add2", 31},
      {"prefetch3", "add3", 32},
      {"prefetch4", "add4", 33},
      {"add2", "prefetch4", 34},
      {"prefetch5", "add5", 35},
      {"add3", "prefetch5", 36},
      // other event
      {"prefetch1", "add5", 37}};
  for (const auto &link : event_links) {
    stream_allocator.AddSendEventId(all_nodes.at(link.sender), link.event_id);
    stream_allocator.AddRecvEventId(all_nodes.at(link.receiver), link.event_id);
  }

  ret = stream_allocator.RefreshEventsWithReuse();
  EXPECT_EQ(ret, SUCCESS);
  EXPECT_EQ((stream_allocator.node_to_send_events_.at(all_nodes.at("prefetch1"))).size(), 2);
  EXPECT_EQ((stream_allocator.node_to_send_events_.at(all_nodes.at("prefetch5"))).size(), 1);
  EXPECT_EQ((stream_allocator.node_to_recv_events_.at(all_nodes.at("prefetch5"))).size(), 1);
  EXPECT_EQ((stream_allocator.node_to_recv_events_.at(all_nodes.at("add5"))).size(), 2);
  EXPECT_EQ(stream_allocator.event_num_, 5);
}
// Checks that StreamAllocator::RefreshEventsWithReuse() returns PARAM_INVALID for several malformed
// ATTR_NAME_EVENT_MULTIPLEXING entries written onto node "prefetch1" of the "NormalGraph" fixture graph.
TEST_F(UtestBufferPoolMemAssignerTest, test_RefreshEventsWithReuse_fail) {
  ut::BufferPoolGraphBuilder builder("NormalGraph");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraph();

  // Index every node by name so that attributes can be set by node name.
  std::map<std::string, NodePtr> all_nodes;
  for (const auto &node : graph->GetAllNodes()) {
    // Fatal assertion: the node is dereferenced on the next line.
    ASSERT_NE(node, nullptr);
    all_nodes[node->GetName()] = node;
  }

  // Baseline event info; entry i is written to node "prefetch<i + 1>".
  std::vector<std::vector<std::string>> event_info = {{"SendTo;add1;0"},
                                                      {"SendTo;add2;1"},
                                                      {"SendTo;add3;2"},
                                                      {"SendTo;add4;3", "RecvFrom;add2;0"},
                                                      {"SendTo;add5;0", "RecvFrom;add3;1"}};
  for (size_t i = 0; i < event_info.size(); ++i) {
    const std::string node_name = "prefetch" + std::to_string(i + 1);
    (void) AttrUtils::SetListStr(all_nodes.at(node_name)->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[i]);
  }

  Graph2SubGraphInfoList sub_graphs;
  StreamAllocator stream_allocator(graph, sub_graphs);
  stream_allocator.event_num_ = 65520;

  // Each case replaces the single event entry of "prefetch1" with a malformed string.
  struct InvalidCase {
    std::string reason;
    std::string raw_event;
  };
  const std::vector<InvalidCase> invalid_cases = {
      {"Item num of raw event info is invalid", "SendTo;add1;0;1"},
      {"Event id is invalid argument", "SendTo;add1;event_id"},
      {"Event id is out of range", "SendTo;add1;666666666666666666666666666666666666666"},
      {"Event id is negative", "SendTo;add1;-2"},
      {"Key word is not supported", "SendToKey;add1;2"}};
  for (const auto &invalid_case : invalid_cases) {
    // Report which malformed input was active if the expectation below fails.
    SCOPED_TRACE(invalid_case.reason);
    event_info[0][0] = invalid_case.raw_event;
    (void) AttrUtils::SetListStr(all_nodes.at("prefetch1")->GetOpDesc(), ATTR_NAME_EVENT_MULTIPLEXING, event_info[0]);
    EXPECT_EQ(stream_allocator.RefreshEventsWithReuse(), PARAM_INVALID);
  }
}
} // namespace ge | |||
@@ -0,0 +1,591 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include "common/ge_inner_error_codes.h" | |||
#include "common/types.h" | |||
#include "graph/manager/graph_var_manager.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "inc/pass_manager.h" | |||
#include "graph_builder_utils.h" | |||
#include "../utils/buffer_pool_graph_builder.h" | |||
#include "graph/passes/buffer_pool_memory_pass.h" | |||
namespace ge { | |||
class UtestBufferPoolMemoryPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
// Runs BufferPoolMemoryPass on the "NormalGraph" fixture graph and checks, for every prefetch node, the
// "_event_multiplexing" attribute and (where listed) the input control nodes.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_normal_success_test) {
  ut::BufferPoolGraphBuilder builder("NormalGraph");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Expected state of one prefetch node after the pass.
  struct PrefetchExpect {
    std::string node_name;
    std::vector<std::string> events;  // expected "_event_multiplexing" value
    bool check_ctrl;                  // whether the input control nodes are verified
    size_t ctrl_num;                  // expected number of input control nodes
    std::string first_ctrl;           // expected name of the first input control node (if ctrl_num > 0)
  };
  const std::vector<PrefetchExpect> expects = {
      {"prefetch1", {"SendTo;add1;0"}, false, 0U, ""},
      {"prefetch2", {"SendTo;add2;1"}, false, 0U, ""},
      {"prefetch3", {"SendTo;add3;2"}, false, 0U, ""},
      {"prefetch4", {"SendTo;add4;3", "RecvFrom;add2;0"}, true, 2U, "add2"},
      {"prefetch5", {"SendTo;add5;0", "RecvFrom;add3;1"}, true, 2U, "add3"}};

  for (const auto &expect : expects) {
    SCOPED_TRACE(expect.node_name);
    auto prefetch = graph->FindNode(expect.node_name);
    // Fatal assertion: the node is dereferenced below.
    ASSERT_NE(prefetch, nullptr);
    std::vector<std::string> event_info;
    (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
    EXPECT_EQ(event_info, expect.events);
    if (!expect.check_ctrl) {
      continue;
    }
    auto in_ctrl_nodes = prefetch->GetInControlNodes();
    // Fatal assertion: the first element is accessed below.
    ASSERT_EQ(in_ctrl_nodes.size(), expect.ctrl_num);
    if (expect.ctrl_num > 0U) {
      EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), expect.first_ctrl);
    }
  }
}
// Runs BufferPoolMemoryPass on the "NormalGraphWithMultiBufferPool" fixture graph and checks, for every prefetch
// node, the "_event_multiplexing" attribute and (where listed) the input control nodes.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_normal_graph_with_multi_buffer_pool_success_test) {
  ut::BufferPoolGraphBuilder builder("NormalGraphWithMultiBufferPool");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraphWithMultiBufferPool();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Expected state of one prefetch node after the pass.
  struct PrefetchExpect {
    std::string node_name;
    std::vector<std::string> events;  // expected "_event_multiplexing" value
    bool check_ctrl;                  // whether the input control nodes are verified
    size_t ctrl_num;                  // expected number of input control nodes
    std::string first_ctrl;           // expected name of the first input control node (if ctrl_num > 0)
  };
  const std::vector<PrefetchExpect> expects = {
      {"prefetch1", {"SendTo;add1;0"}, false, 0U, ""},
      {"prefetch2", {"SendTo;add2;3"}, false, 0U, ""},
      {"prefetch3", {"SendTo;add3;1"}, false, 0U, ""},
      {"prefetch4", {"SendTo;add4;2", "RecvFrom;add3;0"}, true, 2U, "add3"},
      {"prefetch5", {"SendTo;add5;4"}, false, 0U, ""}};

  for (const auto &expect : expects) {
    SCOPED_TRACE(expect.node_name);
    auto prefetch = graph->FindNode(expect.node_name);
    // Fatal assertion: the node is dereferenced below.
    ASSERT_NE(prefetch, nullptr);
    std::vector<std::string> event_info;
    (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
    EXPECT_EQ(event_info, expect.events);
    if (!expect.check_ctrl) {
      continue;
    }
    auto in_ctrl_nodes = prefetch->GetInControlNodes();
    // Fatal assertion: the first element is accessed below.
    ASSERT_EQ(in_ctrl_nodes.size(), expect.ctrl_num);
    if (expect.ctrl_num > 0U) {
      EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), expect.first_ctrl);
    }
  }
}
// Runs BufferPoolMemoryPass on the "SerialGraph" fixture graph and checks, for every prefetch node, the
// "_event_multiplexing" attribute and (where listed) the input control nodes.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_contain_one_node_success_test) {
  ut::BufferPoolGraphBuilder builder("SerialGraph");
  ge::ComputeGraphPtr graph = builder.BuildSerialGraph();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Expected state of one prefetch node after the pass.
  struct PrefetchExpect {
    std::string node_name;
    std::vector<std::string> events;  // expected "_event_multiplexing" value
    bool check_ctrl;                  // whether the input control nodes are verified
    size_t ctrl_num;                  // expected number of input control nodes
    std::string first_ctrl;           // expected name of the first input control node (if ctrl_num > 0)
  };
  const std::vector<PrefetchExpect> expects = {
      {"prefetch1", {"SendTo;add1;0"}, false, 0U, ""},
      {"prefetch2", {"SendTo;add2;1", "RecvFrom;add1;2"}, true, 2U, "add1"},
      {"prefetch3", {"SendTo;add3;2", "RecvFrom;add2;0"}, true, 2U, "add2"},
      {"prefetch4", {"SendTo;add4;0", "RecvFrom;add3;1"}, true, 2U, "add3"},
      {"prefetch5", {"SendTo;add5;1", "RecvFrom;add4;2"}, true, 2U, "add4"}};

  for (const auto &expect : expects) {
    SCOPED_TRACE(expect.node_name);
    auto prefetch = graph->FindNode(expect.node_name);
    // Fatal assertion: the node is dereferenced below.
    ASSERT_NE(prefetch, nullptr);
    std::vector<std::string> event_info;
    (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
    EXPECT_EQ(event_info, expect.events);
    if (!expect.check_ctrl) {
      continue;
    }
    auto in_ctrl_nodes = prefetch->GetInControlNodes();
    // Fatal assertion: the first element is accessed below.
    ASSERT_EQ(in_ctrl_nodes.size(), expect.ctrl_num);
    if (expect.ctrl_num > 0U) {
      EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), expect.first_ctrl);
    }
  }
}
// Runs BufferPoolMemoryPass on the "GraphWithMultiPrefetch" fixture graph and checks, for every prefetch node,
// the "_event_multiplexing" attribute (expected to be empty for prefetch1 and prefetch3) and, where listed,
// the input control nodes.
TEST_F(UtestBufferPoolMemoryPass, calc_node_with_multi_buffer_pool_input_success_test) {
  ut::BufferPoolGraphBuilder builder("GraphWithMultiPrefetch");
  ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiPrefetch();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Expected state of one prefetch node after the pass.
  struct PrefetchExpect {
    std::string node_name;
    std::vector<std::string> events;  // expected "_event_multiplexing" value
    bool check_ctrl;                  // whether the input control nodes are verified
    size_t ctrl_num;                  // expected number of input control nodes
    std::string first_ctrl;           // expected name of the first input control node (if ctrl_num > 0)
  };
  const std::vector<PrefetchExpect> expects = {
      {"prefetch1", {}, false, 0U, ""},
      {"prefetch2", {"SendTo;add1;0"}, false, 0U, ""},
      {"prefetch3", {}, false, 0U, ""},
      {"prefetch4", {"SendTo;add2;1", "RecvFrom;add1;2"}, true, 2U, "add1"},
      {"prefetch5", {"SendTo;add3;2", "RecvFrom;add2;0"}, true, 2U, "add2"}};

  for (const auto &expect : expects) {
    SCOPED_TRACE(expect.node_name);
    auto prefetch = graph->FindNode(expect.node_name);
    // Fatal assertion: the node is dereferenced below.
    ASSERT_NE(prefetch, nullptr);
    std::vector<std::string> event_info;
    (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
    EXPECT_EQ(event_info, expect.events);
    if (!expect.check_ctrl) {
      continue;
    }
    auto in_ctrl_nodes = prefetch->GetInControlNodes();
    // Fatal assertion: the first element is accessed below.
    ASSERT_EQ(in_ctrl_nodes.size(), expect.ctrl_num);
    if (expect.ctrl_num > 0U) {
      EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), expect.first_ctrl);
    }
  }
}
// Runs BufferPoolMemoryPass on the "GraphWithSubgraph" fixture graph and checks, for every prefetch node
// (looked up among all nodes returned by GetAllNodes()), the "_event_multiplexing" attribute and, where listed,
// the input control nodes.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_in_different_subgraph_success_test) {
  ut::BufferPoolGraphBuilder builder("GraphWithSubgraph");
  ge::ComputeGraphPtr graph = builder.BuildGraphWithSubgraph();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Index every node by name.
  std::map<std::string, NodePtr> all_nodes;
  for (const auto &node : graph->GetAllNodes()) {
    // Fatal assertion: the node is dereferenced on the next line.
    ASSERT_NE(node, nullptr);
    all_nodes[node->GetName()] = node;
  }

  // Expected state of one prefetch node after the pass.
  struct PrefetchExpect {
    std::string node_name;
    std::vector<std::string> events;  // expected "_event_multiplexing" value
    bool check_ctrl;                  // whether the input control nodes are verified
    size_t ctrl_num;                  // expected number of input control nodes
    std::string first_ctrl;           // expected name of the first input control node (if ctrl_num > 0)
  };
  const std::vector<PrefetchExpect> expects = {
      {"prefetch1", {"SendTo;add1;0"}, false, 0U, ""},
      {"prefetch2", {"SendTo;add2;1"}, false, 0U, ""},
      {"prefetch3", {"SendTo;add3;2"}, false, 0U, ""},
      {"prefetch4", {"SendTo;add4;3"}, true, 0U, ""},
      {"prefetch5", {"SendTo;add5;4"}, true, 1U, "prefetch4"}};

  for (const auto &expect : expects) {
    SCOPED_TRACE(expect.node_name);
    auto prefetch = all_nodes.at(expect.node_name);
    // Fatal assertion: the node is dereferenced below.
    ASSERT_NE(prefetch, nullptr);
    std::vector<std::string> event_info;
    (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
    EXPECT_EQ(event_info, expect.events);
    if (!expect.check_ctrl) {
      continue;
    }
    auto in_ctrl_nodes = prefetch->GetInControlNodes();
    // Fatal assertion: the first element is accessed below.
    ASSERT_EQ(in_ctrl_nodes.size(), expect.ctrl_num);
    if (expect.ctrl_num > 0U) {
      EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), expect.first_ctrl);
    }
  }
}
// Runs BufferPoolMemoryPass on the "SubgraphWithInnerDependency" fixture graph and checks, for every prefetch
// node (looked up among all nodes returned by GetAllNodes()), the "_event_multiplexing" attribute and, where
// listed, the input control nodes.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_in_different_subgraph_with_inner_dependency_success_test) {
  ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency");
  ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Index every node by name.
  std::map<std::string, NodePtr> all_nodes;
  for (const auto &node : graph->GetAllNodes()) {
    // Fatal assertion: the node is dereferenced on the next line.
    ASSERT_NE(node, nullptr);
    all_nodes[node->GetName()] = node;
  }

  // Expected state of one prefetch node after the pass.
  struct PrefetchExpect {
    std::string node_name;
    std::vector<std::string> events;  // expected "_event_multiplexing" value
    bool check_ctrl;                  // whether the input control nodes are verified
    size_t ctrl_num;                  // expected number of input control nodes
    std::string first_ctrl;           // expected name of the first input control node (if ctrl_num > 0)
  };
  const std::vector<PrefetchExpect> expects = {
      {"prefetch1", {"SendTo;add1;0"}, false, 0U, ""},
      {"prefetch2", {"SendTo;add2;1"}, false, 0U, ""},
      {"prefetch3", {"SendTo;add3;2"}, false, 0U, ""},
      {"prefetch4", {"SendTo;add4;3"}, true, 1U, "prefetch3"},
      {"prefetch5", {"SendTo;add5;4", "RecvFrom;add3;0"}, true, 2U, "add3"}};

  for (const auto &expect : expects) {
    SCOPED_TRACE(expect.node_name);
    auto prefetch = all_nodes.at(expect.node_name);
    // Fatal assertion: the node is dereferenced below.
    ASSERT_NE(prefetch, nullptr);
    std::vector<std::string> event_info;
    (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
    EXPECT_EQ(event_info, expect.events);
    if (!expect.check_ctrl) {
      continue;
    }
    auto in_ctrl_nodes = prefetch->GetInControlNodes();
    // Fatal assertion: the first element is accessed below.
    ASSERT_EQ(in_ctrl_nodes.size(), expect.ctrl_num);
    if (expect.ctrl_num > 0U) {
      EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), expect.first_ctrl);
    }
  }
}
// Runs BufferPoolMemoryPass on the "GraphWithMultiBatch" fixture graph and checks, for every
// "batch_label_256/" prefetch node, the "_event_multiplexing" attribute and (where listed) the input control nodes.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_with_batch_label_success_test) {
  ut::BufferPoolGraphBuilder builder("GraphWithMultiBatch");
  ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiBatch();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Expected state of one prefetch node after the pass.
  struct PrefetchExpect {
    std::string node_name;
    std::vector<std::string> events;  // expected "_event_multiplexing" value
    bool check_ctrl;                  // whether the input control nodes are verified
    size_t ctrl_num;                  // expected number of input control nodes
    std::string first_ctrl;           // expected name of the first input control node (if ctrl_num > 0)
  };
  const std::vector<PrefetchExpect> expects = {
      {"batch_label_256/prefetch1", {"SendTo;batch_label_256/add1;4"}, false, 0U, ""},
      {"batch_label_256/prefetch2", {"SendTo;batch_label_256/add2;5"}, false, 0U, ""},
      {"batch_label_256/prefetch3", {"SendTo;batch_label_256/add3;6"}, false, 0U, ""},
      {"batch_label_256/prefetch4",
       {"SendTo;batch_label_256/add4;7", "RecvFrom;batch_label_256/add2;4"},
       true,
       2U,
       "batch_label_256/add2"},
      {"batch_label_256/prefetch5",
       {"SendTo;batch_label_256/add5;4", "RecvFrom;batch_label_256/add3;5"},
       true,
       2U,
       "batch_label_256/add3"}};

  for (const auto &expect : expects) {
    SCOPED_TRACE(expect.node_name);
    auto prefetch = graph->FindNode(expect.node_name);
    // Fatal assertion: the node is dereferenced below.
    ASSERT_NE(prefetch, nullptr);
    std::vector<std::string> event_info;
    (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
    EXPECT_EQ(event_info, expect.events);
    if (!expect.check_ctrl) {
      continue;
    }
    auto in_ctrl_nodes = prefetch->GetInControlNodes();
    // Fatal assertion: the first element is accessed below.
    ASSERT_EQ(in_ctrl_nodes.size(), expect.ctrl_num);
    if (expect.ctrl_num > 0U) {
      EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), expect.first_ctrl);
    }
  }
}
// Runs BufferPoolMemoryPass on the "GraphWithMultiOutputPrefetch" fixture graph and checks, for every prefetch
// node, the "_event_multiplexing" attribute and (where listed) the input control nodes.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_node_has_multi_output_success_test) {
  ut::BufferPoolGraphBuilder builder("GraphWithMultiOutputPrefetch");
  ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiOutputPrefetch();
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // Expected state of one prefetch node after the pass.
  struct PrefetchExpect {
    std::string node_name;
    std::vector<std::string> events;  // expected "_event_multiplexing" value
    bool check_ctrl;                  // whether the input control nodes are verified
    size_t ctrl_num;                  // expected number of input control nodes
    std::string first_ctrl;           // expected name of the first input control node (if ctrl_num > 0)
  };
  const std::vector<PrefetchExpect> expects = {
      {"prefetch1", {"SendTo;prefetch1_memcpy_async;0"}, false, 0U, ""},
      {"prefetch2", {"SendTo;prefetch2_memcpy_async;1"}, false, 0U, ""},
      {"prefetch3", {"SendTo;prefetch3_memcpy_async;2"}, false, 0U, ""},
      {"prefetch4",
       {"SendTo;prefetch4_memcpy_async;3", "RecvFrom;prefetch2_memcpy_async;0"},
       true,
       2U,
       "prefetch2_memcpy_async"},
      {"prefetch5", {"SendTo;add5;0", "RecvFrom;prefetch3_memcpy_async;1"}, true, 2U, "prefetch3_memcpy_async"}};

  for (const auto &expect : expects) {
    SCOPED_TRACE(expect.node_name);
    auto prefetch = graph->FindNode(expect.node_name);
    // Fatal assertion: the node is dereferenced below.
    ASSERT_NE(prefetch, nullptr);
    std::vector<std::string> event_info;
    (void) AttrUtils::GetListStr(prefetch->GetOpDesc(), "_event_multiplexing", event_info);
    EXPECT_EQ(event_info, expect.events);
    if (!expect.check_ctrl) {
      continue;
    }
    auto in_ctrl_nodes = prefetch->GetInControlNodes();
    // Fatal assertion: the first element is accessed below.
    ASSERT_EQ(in_ctrl_nodes.size(), expect.ctrl_num);
    if (expect.ctrl_num > 0U) {
      EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), expect.first_ctrl);
    }
  }
}
// Overrides the "_buffer_pool_size" attribute of "prefetch3" in the "NormalGraph" fixture graph with 256 and
// checks that BufferPoolMemoryPass::Run() returns FAILED.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_has_different_size_fail_test) {
  ut::BufferPoolGraphBuilder builder("NormalGraph");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
  const int64_t dummy_size = 256;
  auto prefetch = graph->FindNode("prefetch3");
  // Fatal assertion: the node is dereferenced on the next line.
  ASSERT_NE(prefetch, nullptr);
  (void) AttrUtils::SetInt(prefetch->GetOpDesc(), "_buffer_pool_size", dummy_size);
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, FAILED);
}
// Reconfigures "prefetch3" in the "NormalGraph" fixture graph through SetPrefetchNodeInfo() with pool id 0,
// pool size 5600 and a last argument of {5600 + 512}, then checks that BufferPoolMemoryPass::Run() returns FAILED.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_size_is_not_enough_fail_test) {
  ut::BufferPoolGraphBuilder builder("NormalGraph");
  ge::ComputeGraphPtr graph = builder.BuildNormalGraph();
  const int64_t buffer_pool_id = 0;
  const int64_t buffer_pool_size = 5600;
  auto prefetch = graph->FindNode("prefetch3");
  // Fatal assertion: do not hand a null node to the builder.
  ASSERT_NE(prefetch, nullptr);
  builder.SetPrefetchNodeInfo(prefetch, buffer_pool_id, buffer_pool_size, {buffer_pool_size + 512});
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, FAILED);
}
// In the multi-prefetch graph add2 consumes both prefetch3 and prefetch4. Here prefetch3's output
// is set to the full pool size, and the pass is expected to fail
// (presumably because the two inputs of add2 cannot share the pool - confirm against the pass).
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_size_is_not_enough_for_multi_fail_test) {
  ut::BufferPoolGraphBuilder builder("GraphWithMultiPrefetch");
  ge::ComputeGraphPtr graph = builder.BuildGraphWithMultiPrefetch();
  // Fatal checks: both pointers are dereferenced immediately afterwards.
  ASSERT_NE(graph, nullptr);
  const int64_t buffer_pool_id = 0;
  const int64_t buffer_pool_size = 5600;
  auto prefetch = graph->FindNode("prefetch3");
  ASSERT_NE(prefetch, nullptr);
  builder.SetPrefetchNodeInfo(prefetch, buffer_pool_id, buffer_pool_size, {buffer_pool_size});
  BufferPoolMemoryPass buffer_pool_mem_pass;
  Status ret = buffer_pool_mem_pass.Run(graph);
  EXPECT_EQ(ret, FAILED);
}
// Runs the pass on the graph produced by BuildGraphWithMultiInputOutputPrefetch and expects FAILED.
TEST_F(UtestBufferPoolMemoryPass, buffer_pool_node_has_multi_input_output_fail_test) {
  ut::BufferPoolGraphBuilder graph_builder("GraphWithMultiInputOutputPrefetch");
  ge::ComputeGraphPtr compute_graph = graph_builder.BuildGraphWithMultiInputOutputPrefetch();
  BufferPoolMemoryPass pass_under_test;
  const Status run_status = pass_under_test.Run(compute_graph);
  EXPECT_EQ(run_status, FAILED);
}
} // namespace ge |
@@ -0,0 +1,978 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include "buffer_pool_graph_builder.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
namespace ge { | |||
namespace ut { | |||
/// Remembers the name that the Build* methods hand to the graph they create.
BufferPoolGraphBuilder::BufferPoolGraphBuilder(const std::string &name) : graph_name_(name) {}
/// Creates the empty compute graph, named @p name, that nodes and edges are later added to.
BufferPoolGraphBuilder::InnerGraphBuilder::InnerGraphBuilder(const std::string &name)
    : graph_(std::make_shared<ComputeGraph>(name)) {
  EXPECT_NE(graph_, nullptr);
}
/// Adds a node to the graph under construction.
///
/// All @p in_cnt input descriptors and @p out_cnt output descriptors are copies of one tensor
/// descriptor built from @p shape, @p format and @p data_type.
///
/// @return the node returned by the graph's AddNode for the new op descriptor.
NodePtr BufferPoolGraphBuilder::InnerGraphBuilder::AddNode(const std::string &name, const std::string &type,
                                                           int in_cnt, int out_cnt,
                                                           Format format, DataType data_type,
                                                           std::vector<int64_t> shape) {
  // Template descriptor; every anchor receives its own clone.
  auto desc_template = std::make_shared<GeTensorDesc>();
  EXPECT_NE(desc_template, nullptr);
  desc_template->SetShape(GeShape(std::move(shape)));
  desc_template->SetFormat(format);
  desc_template->SetDataType(data_type);

  auto node_desc = std::make_shared<OpDesc>(name, type);
  EXPECT_NE(node_desc, nullptr);
  for (int remaining_inputs = in_cnt; remaining_inputs > 0; --remaining_inputs) {
    node_desc->AddInputDesc(desc_template->Clone());
  }
  for (int remaining_outputs = out_cnt; remaining_outputs > 0; --remaining_outputs) {
    node_desc->AddOutputDesc(desc_template->Clone());
  }

  return graph_->AddNode(node_desc);
}
/// Requests a data edge from output @p src_idx of @p src_node to input @p dst_idx of @p dst_node.
///
/// The null checks are fatal (ASSERT_*): on a null node a test failure is recorded and this
/// function returns before dereferencing the pointer. The status returned by GraphUtils::AddEdge
/// is not inspected here.
void BufferPoolGraphBuilder::InnerGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx,
                                                            NodePtr &dst_node, int dst_idx) {
  ASSERT_NE(src_node, nullptr);
  ASSERT_NE(dst_node, nullptr);
  GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
/// Requests a control edge from @p src_node to @p dst_node.
///
/// The null checks are fatal (ASSERT_*): on a null node a test failure is recorded and this
/// function returns before dereferencing the pointer. The status returned by GraphUtils::AddEdge
/// is not inspected here.
void BufferPoolGraphBuilder::InnerGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
  ASSERT_NE(src_node, nullptr);
  ASSERT_NE(dst_node, nullptr);
  GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}
/// Writes the buffer pool id and size attributes on @p node, and the batch label attribute when
/// @p batch_label is not empty.
///
/// A null @p node records a fatal test failure and returns without touching the node.
void BufferPoolGraphBuilder::SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size,
                                           const std::string &batch_label) {
  ASSERT_NE(node, nullptr);
  (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, pool_id);
  (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, pool_size);
  if (!batch_label.empty()) {
    (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
  }
}
/// Writes the batch label attribute on @p node.
///
/// A null @p node records a fatal test failure and returns without touching the node.
void BufferPoolGraphBuilder::SetBatchLabel(NodePtr &node, const std::string &batch_label) {
  ASSERT_NE(node, nullptr);
  (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
}
/// Sets the tensor size of every output descriptor of @p node; mem_size[i] goes to output i.
///
/// @p mem_size must contain exactly one entry per output. All precondition checks are fatal so
/// that a null node/op desc/output desc or a size mismatch records a test failure and returns
/// instead of dereferencing a null pointer or reading past the end of @p mem_size.
void BufferPoolGraphBuilder::SetOutputMemSize(NodePtr &node, const std::vector<int64_t> &mem_size) {
  ASSERT_NE(node, nullptr);
  auto op_desc = node->GetOpDesc();
  ASSERT_NE(op_desc, nullptr);
  const size_t output_size = op_desc->GetOutputsSize();
  ASSERT_EQ(output_size, mem_size.size());
  for (size_t i = 0; i < output_size; ++i) {
    auto output_op_desc = op_desc->MutableOutputDesc(i);
    ASSERT_NE(output_op_desc, nullptr);
    ge::TensorUtils::SetSize(*output_op_desc, mem_size[i]);
  }
}
/// Stores @p ws_bytes as the workspace byte sizes of @p node's op descriptor.
///
/// A null node or op descriptor records a fatal test failure and returns without dereferencing it.
void BufferPoolGraphBuilder::SetWorkSpaceMemSize(NodePtr &node, const std::vector<int64_t> &ws_bytes) {
  ASSERT_NE(node, nullptr);
  auto op_desc = node->GetOpDesc();
  ASSERT_NE(op_desc, nullptr);
  op_desc->SetWorkspaceBytes(ws_bytes);
}
/// Configures @p node as a prefetch node in one call: buffer pool id/size (plus optional batch
/// label), per-output tensor sizes (@p mem_size) and workspace sizes (@p ws_bytes).
///
/// A null @p node records a fatal test failure and returns before any helper touches the node.
void BufferPoolGraphBuilder::SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size,
                                                 const std::vector<int64_t> &mem_size,
                                                 const std::vector<int64_t> &ws_bytes,
                                                 const std::string &batch_label) {
  ASSERT_NE(node, nullptr);
  SetBufferPool(node, pool_id, pool_size, batch_label);
  SetOutputMemSize(node, mem_size);
  SetWorkSpaceMemSize(node, ws_bytes);
}
///
/// Normal graph: a chain of five adds, each fed by one prefetch node; all prefetch nodes use
/// buffer pool 0 (pool size 5600) with output sizes 500, 500, 500, 1024, 1024.
///
///              w1         w2         w3         w4         w5
///               \          \          \          \          \
///           prefetch1  prefetch2  prefetch3  prefetch4  prefetch5
///                \          \          \          \          \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
///  Memory distribution:
///
///      |___w1__|__w2__|__w3__|__|
///
///      |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraph() {
  const int64_t pool_id = 0;
  const int64_t pool_bytes = 5600;
  const std::vector<int64_t> prefetch_out_sizes = {500, 500, 500, 1024, 1024};
  const size_t chain_len = prefetch_out_sizes.size();

  InnerGraphBuilder inner(graph_name_);

  // Node creation order: w1..w5, prefetch1..prefetch5, add1..add5, const1, net_output.
  std::vector<NodePtr> weights;
  for (size_t i = 0; i < chain_len; ++i) {
    weights.push_back(inner.AddNode("w" + std::to_string(i + 1), VARIABLE, 0, 1));
  }
  std::vector<NodePtr> prefetches;
  for (size_t i = 0; i < chain_len; ++i) {
    NodePtr prefetch = inner.AddNode("prefetch" + std::to_string(i + 1), HCOMALLGATHER, 1, 1);
    SetPrefetchNodeInfo(prefetch, pool_id, pool_bytes, {prefetch_out_sizes[i]});
    prefetches.push_back(prefetch);
  }
  std::vector<NodePtr> adds;
  for (size_t i = 0; i < chain_len; ++i) {
    adds.push_back(inner.AddNode("add" + std::to_string(i + 1), ADD, 2, 1));
  }
  NodePtr chain_head = inner.AddNode("const1", CONSTANTOP, 0, 1);
  NodePtr chain_tail = inner.AddNode("net_output", NETOUTPUT, 1, 0);

  // Weight -> prefetch edges first, then the add chain from const1 to net_output.
  for (size_t i = 0; i < chain_len; ++i) {
    inner.AddDataEdge(weights[i], 0, prefetches[i], 0);
  }
  for (size_t i = 0; i < chain_len; ++i) {
    NodePtr &upstream = (i == 0) ? chain_head : adds[i - 1];
    inner.AddDataEdge(upstream, 0, adds[i], 0);
    inner.AddDataEdge(prefetches[i], 0, adds[i], 1);
  }
  inner.AddDataEdge(adds.back(), 0, chain_tail, 0);

  return inner.GetGraph();
}
///
/// Normal graph with multi buffer pool: same add chain as the normal graph, but the prefetch
/// nodes are split over two pools (pool size 5000 each).
///
///              w1         w2         w3         w4         w5
///               \          \          \          \          \
///           prefetch1  prefetch2  prefetch3  prefetch4  prefetch5
///            (pool0)    (pool1)    (pool0)    (pool0)    (pool1)
///                \          \          \          \          \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
///  Memory distribution:
///
///      |___w1__|__w3__|_________|
///      |_____w4_____|___________|
///
///      |___w2__|_____w5___|_____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraphWithMultiBufferPool() {
  const int64_t pool_bytes = 5000;
  // Per prefetch node: which pool it belongs to and the size of its single output.
  const std::vector<int64_t> prefetch_pool_ids = {0, 1, 0, 0, 1};
  const std::vector<int64_t> prefetch_out_sizes = {500, 500, 500, 1024, 1024};
  const size_t chain_len = prefetch_out_sizes.size();

  InnerGraphBuilder inner(graph_name_);

  // Node creation order: w1..w5, prefetch1..prefetch5, add1..add5, const1, net_output.
  std::vector<NodePtr> weights;
  for (size_t i = 0; i < chain_len; ++i) {
    weights.push_back(inner.AddNode("w" + std::to_string(i + 1), VARIABLE, 0, 1));
  }
  std::vector<NodePtr> prefetches;
  for (size_t i = 0; i < chain_len; ++i) {
    NodePtr prefetch = inner.AddNode("prefetch" + std::to_string(i + 1), HCOMALLGATHER, 1, 1);
    SetPrefetchNodeInfo(prefetch, prefetch_pool_ids[i], pool_bytes, {prefetch_out_sizes[i]});
    prefetches.push_back(prefetch);
  }
  std::vector<NodePtr> adds;
  for (size_t i = 0; i < chain_len; ++i) {
    adds.push_back(inner.AddNode("add" + std::to_string(i + 1), ADD, 2, 1));
  }
  NodePtr chain_head = inner.AddNode("const1", CONSTANTOP, 0, 1);
  NodePtr chain_tail = inner.AddNode("net_output", NETOUTPUT, 1, 0);

  // Weight -> prefetch edges first, then the add chain from const1 to net_output.
  for (size_t i = 0; i < chain_len; ++i) {
    inner.AddDataEdge(weights[i], 0, prefetches[i], 0);
  }
  for (size_t i = 0; i < chain_len; ++i) {
    NodePtr &upstream = (i == 0) ? chain_head : adds[i - 1];
    inner.AddDataEdge(upstream, 0, adds[i], 0);
    inner.AddDataEdge(prefetches[i], 0, adds[i], 1);
  }
  inner.AddDataEdge(adds.back(), 0, chain_tail, 0);

  return inner.GetGraph();
}
///
/// SerialGraph: the buffer pool (pool size 2048) can hold only one prefetch node at a time.
///
///              w1         w2         w3         w4         w5
///               \          \          \          \          \
///           prefetch1  prefetch2  prefetch3  prefetch4  prefetch5
///                \          \          \          \          \
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
///  Memory distribution:
///
///      |____w1_____|__|
///
///      |____w2_____|__|
///
///      |____w3_____|__|
///
///      |______w4______|
///
///      |______w5______|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildSerialGraph() {
  const int64_t pool_id = 0;
  const int64_t pool_bytes = 2048;
  const std::vector<int64_t> prefetch_out_sizes = {500, 500, 500, 1024, 1024};
  const size_t chain_len = prefetch_out_sizes.size();

  InnerGraphBuilder inner(graph_name_);

  // Node creation order: w1..w5, prefetch1..prefetch5, add1..add5, const1, net_output.
  std::vector<NodePtr> weights;
  for (size_t i = 0; i < chain_len; ++i) {
    weights.push_back(inner.AddNode("w" + std::to_string(i + 1), VARIABLE, 0, 1));
  }
  std::vector<NodePtr> prefetches;
  for (size_t i = 0; i < chain_len; ++i) {
    NodePtr prefetch = inner.AddNode("prefetch" + std::to_string(i + 1), HCOMALLGATHER, 1, 1);
    SetPrefetchNodeInfo(prefetch, pool_id, pool_bytes, {prefetch_out_sizes[i]});
    prefetches.push_back(prefetch);
  }
  std::vector<NodePtr> adds;
  for (size_t i = 0; i < chain_len; ++i) {
    adds.push_back(inner.AddNode("add" + std::to_string(i + 1), ADD, 2, 1));
  }
  NodePtr chain_head = inner.AddNode("const1", CONSTANTOP, 0, 1);
  NodePtr chain_tail = inner.AddNode("net_output", NETOUTPUT, 1, 0);

  // Weight -> prefetch edges first, then the add chain from const1 to net_output.
  for (size_t i = 0; i < chain_len; ++i) {
    inner.AddDataEdge(weights[i], 0, prefetches[i], 0);
  }
  for (size_t i = 0; i < chain_len; ++i) {
    NodePtr &upstream = (i == 0) ? chain_head : adds[i - 1];
    inner.AddDataEdge(upstream, 0, adds[i], 0);
    inner.AddDataEdge(prefetches[i], 0, adds[i], 1);
  }
  inner.AddDataEdge(adds.back(), 0, chain_tail, 0);

  return inner.GetGraph();
}
///
/// GraphWithMultiPrefetch: calc nodes that consume more than one prefetch node. The adds are
/// ordered by control edges ("c") and each feeds its own input of net_output.
///
///     w1          w2           w3          w4           w5
///      \           \            \           \            \
///   prefetch1   prefetch2    prefetch3   prefetch4    prefetch5   const1
///        \       /                \       /                \       /
///         \     /                  \     /                  \     /
///          \   /                    \   /                    \   /
///          add1 ------- c -------- add2 ------- c --------- add3
///            |                       |                        |
///            |                       |                        |
///            ------------------ net_output --------------------
///
///  Memory distribution:
///
///      |___w1__|__w2__|__w3__|__|
///
///      |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiPrefetch() {
  const int64_t pool_id = 0;
  const int64_t pool_bytes = 5600;
  const std::vector<int64_t> prefetch_out_sizes = {500, 500, 500, 1024, 1024};
  const size_t prefetch_cnt = prefetch_out_sizes.size();

  InnerGraphBuilder inner(graph_name_);

  // Node creation order: w1..w5, prefetch1..prefetch5, const1, add1..add3, net_output.
  std::vector<NodePtr> weights;
  for (size_t i = 0; i < prefetch_cnt; ++i) {
    weights.push_back(inner.AddNode("w" + std::to_string(i + 1), VARIABLE, 0, 1));
  }
  std::vector<NodePtr> prefetches;
  for (size_t i = 0; i < prefetch_cnt; ++i) {
    NodePtr prefetch = inner.AddNode("prefetch" + std::to_string(i + 1), HCOMALLGATHER, 1, 1);
    SetPrefetchNodeInfo(prefetch, pool_id, pool_bytes, {prefetch_out_sizes[i]});
    prefetches.push_back(prefetch);
  }
  NodePtr constant = inner.AddNode("const1", CONSTANTOP, 0, 1);
  NodePtr first_add = inner.AddNode("add1", ADD, 2, 1);
  NodePtr second_add = inner.AddNode("add2", ADD, 2, 1);
  NodePtr third_add = inner.AddNode("add3", ADD, 2, 1);
  NodePtr sink = inner.AddNode("net_output", NETOUTPUT, 3, 0);

  for (size_t i = 0; i < prefetch_cnt; ++i) {
    inner.AddDataEdge(weights[i], 0, prefetches[i], 0);
  }

  // add1 <- prefetch1, prefetch2; add2 <- prefetch3, prefetch4; add3 <- const1, prefetch5.
  inner.AddDataEdge(prefetches[0], 0, first_add, 0);
  inner.AddDataEdge(prefetches[1], 0, first_add, 1);
  inner.AddDataEdge(prefetches[2], 0, second_add, 0);
  inner.AddDataEdge(prefetches[3], 0, second_add, 1);
  inner.AddDataEdge(constant, 0, third_add, 0);
  inner.AddDataEdge(prefetches[4], 0, third_add, 1);

  inner.AddDataEdge(first_add, 0, sink, 0);
  inner.AddDataEdge(second_add, 0, sink, 1);
  inner.AddDataEdge(third_add, 0, sink, 2);

  // Control edges between the adds: add1 -> add2 -> add3.
  inner.AddControlEdge(first_add, second_add);
  inner.AddControlEdge(second_add, third_add);

  return inner.GetGraph();
}
///
/// GraphWithSubgraph: calc nodes live in two different subgraphs, each owned by a
/// PARTITIONEDCALL node of the root graph. All prefetch nodes use buffer pool 0 (pool size 5600).
///
///
/// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output
///
///
/// Subgraph1:                                               Subgraph2:
///
///              w1         w2         w3                                 w4         w5
///               \          \          \                                  \          \
///           prefetch1  prefetch2  prefetch3                          prefetch4  prefetch5
///                \          \          \                                  \          \
/// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out   data1 ---- add4 ----- add5 ---- subgraph2_out
///
///
///  Memory distribution:
///
///      |___w1__|__w2__|__w3__|__|
///
///      |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithSubgraph() {
  auto builder = InnerGraphBuilder(graph_name_);

  const int64_t buffer_pool_id = 0;
  const int64_t buffer_pool_size = 5600;

  // Subgraph1
  auto subgraph_builder1 = InnerGraphBuilder("Subgraph1");
  auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1);
  auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1);
  auto w3 = subgraph_builder1.AddNode("w3", VARIABLE, 0, 1);
  auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
  auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
  auto prefetch3 = subgraph_builder1.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
  auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0);
  auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1);
  auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1);
  auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1);
  auto add3 = subgraph_builder1.AddNode("add3", ADD, 2, 1);
  subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0);
  subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0);
  subgraph_builder1.AddDataEdge(w3, 0, prefetch3, 0);
  subgraph_builder1.AddDataEdge(const1, 0, add1, 0);
  subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1);
  subgraph_builder1.AddDataEdge(add1, 0, add2, 0);
  subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1);
  subgraph_builder1.AddDataEdge(add2, 0, add3, 0);
  subgraph_builder1.AddDataEdge(prefetch3, 0, add3, 1);
  subgraph_builder1.AddDataEdge(add3, 0, subgraph1_out, 0);
  auto subgraph1 = subgraph_builder1.GetGraph();
  for (auto &node : subgraph1->GetDirectNode()) {
    node->SetOwnerComputeGraph(subgraph1);
  }

  // Subgraph2
  auto subgraph_builder2 = InnerGraphBuilder("Subgraph2");
  auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1);
  auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1);
  auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
  auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1);
  auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1);
  auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1);
  auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1);
  subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0);
  subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0);
  subgraph_builder2.AddDataEdge(data1, 0, add4, 0);
  subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1);
  subgraph_builder2.AddDataEdge(add4, 0, add5, 0);
  subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1);
  subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0);
  auto subgraph2 = subgraph_builder2.GetGraph();
  for (auto &node : subgraph2->GetDirectNode()) {
    node->SetOwnerComputeGraph(subgraph2);
  }

  // root graph
  auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1);
  // call_node2 needs one output: it is the source of the edge to net_output below. With zero
  // outputs there is no output anchor 0 and the edge request cannot be satisfied.
  auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 1);
  auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
  builder.AddDataEdge(call_node1, 0, call_node2, 0);
  builder.AddDataEdge(call_node2, 0, net_output, 0);
  auto compute_graph = builder.GetGraph();

  // Wire each call node to its subgraph (instance name at index 0) and register the subgraphs
  // on the root graph.
  call_node1->SetOwnerComputeGraph(compute_graph);
  call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName());
  call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName());
  call_node2->SetOwnerComputeGraph(compute_graph);
  call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName());
  call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName());
  subgraph1->SetParentNode(call_node1);
  subgraph1->SetParentGraph(compute_graph);
  subgraph2->SetParentNode(call_node2);
  subgraph2->SetParentGraph(compute_graph);
  compute_graph->AddSubGraph(subgraph1);
  compute_graph->AddSubGraph(subgraph2);
  return compute_graph;
}
///
/// SubgraphWithInnerDependency: calc nodes live in two different subgraphs with an inner
/// dependency; prefetch3 sits in Subgraph2 together with prefetch4 and prefetch5.
/// All prefetch nodes use buffer pool 0 (pool size 5600).
///
///
/// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output
///
///
/// Subgraph1:                                    Subgraph2:
///
///              w1         w2                                 w3         w4         w5
///               \          \                                  \          \          \
///           prefetch1  prefetch2                          prefetch3  prefetch4  prefetch5
///                \          \                                  \          \          \
/// const1 ----- add1 ----- add2 ----- subgraph1_out   data1 ---- add3 ----- add4 ----- add5 ---- subgraph2_out
///
///
///  Memory distribution:
///
///      |___w1__|__w2__|__w3__|__|
///
///      |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildSubgraphWithInnerDependency() {
  auto builder = InnerGraphBuilder(graph_name_);

  const int64_t buffer_pool_id = 0;
  const int64_t buffer_pool_size = 5600;

  // Subgraph1
  auto subgraph_builder1 = InnerGraphBuilder("Subgraph1");
  auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1);
  auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1);
  auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
  auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
  auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0);
  auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1);
  auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1);
  auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1);
  subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0);
  subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0);
  subgraph_builder1.AddDataEdge(const1, 0, add1, 0);
  subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1);
  subgraph_builder1.AddDataEdge(add1, 0, add2, 0);
  subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1);
  subgraph_builder1.AddDataEdge(add2, 0, subgraph1_out, 0);
  auto subgraph1 = subgraph_builder1.GetGraph();
  for (auto &node : subgraph1->GetDirectNode()) {
    node->SetOwnerComputeGraph(subgraph1);
  }

  // Subgraph2
  auto subgraph_builder2 = InnerGraphBuilder("Subgraph2");
  auto w3 = subgraph_builder2.AddNode("w3", VARIABLE, 0, 1);
  auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1);
  auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1);
  auto prefetch3 = subgraph_builder2.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
  auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
  auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  auto add3 = subgraph_builder2.AddNode("add3", ADD, 2, 1);
  auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1);
  auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1);
  auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1);
  auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1);
  subgraph_builder2.AddDataEdge(w3, 0, prefetch3, 0);
  subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0);
  subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0);
  subgraph_builder2.AddDataEdge(data1, 0, add3, 0);
  subgraph_builder2.AddDataEdge(prefetch3, 0, add3, 1);
  subgraph_builder2.AddDataEdge(add3, 0, add4, 0);
  subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1);
  subgraph_builder2.AddDataEdge(add4, 0, add5, 0);
  subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1);
  subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0);
  auto subgraph2 = subgraph_builder2.GetGraph();
  for (auto &node : subgraph2->GetDirectNode()) {
    node->SetOwnerComputeGraph(subgraph2);
  }

  // root graph
  auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1);
  // call_node2 needs one output: it is the source of the edge to net_output below. With zero
  // outputs there is no output anchor 0 and the edge request cannot be satisfied.
  auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 1);
  // The root net_output belongs to the root graph, so it is created through `builder`
  // (not through the Subgraph2 builder).
  auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
  builder.AddDataEdge(call_node1, 0, call_node2, 0);
  builder.AddDataEdge(call_node2, 0, net_output, 0);
  auto compute_graph = builder.GetGraph();

  // Wire each call node to its subgraph (instance name at index 0) and register the subgraphs
  // on the root graph.
  call_node1->SetOwnerComputeGraph(compute_graph);
  call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName());
  call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName());
  call_node2->SetOwnerComputeGraph(compute_graph);
  call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName());
  call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName());
  subgraph1->SetParentNode(call_node1);
  subgraph1->SetParentGraph(compute_graph);
  subgraph2->SetParentNode(call_node2);
  subgraph2->SetParentGraph(compute_graph);
  compute_graph->AddSubGraph(subgraph1);
  compute_graph->AddSubGraph(subgraph2);
  return compute_graph;
}
/// | |||
/// BuildGraphWithMultiBatch: Different batch label | |||
/// | |||
/// | |||
/// batch_label_128 | |||
/// | |||
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- | |||
/// / / / / / / \ | |||
/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \ | |||
/// const1 switch_false / / / / / \ | |||
/// \ / / / / / / \ | |||
/// switch1 w1 w2 w3 w4 w5 merge1 -- net_output | |||
/// / \ \ \ \ \ \ / | |||
/// const2 switch_true \ \ \ \ \ / | |||
/// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 / | |||
/// \ \ \ \ \ \ / | |||
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- | |||
/// | |||
/// batch_label_256 | |||
/// | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w2__|__w3__|__| | |||
/// | |||
/// |_____w4_____|_____w5____| | |||
/// | |||
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiBatch() { | |||
auto builder = InnerGraphBuilder(graph_name_); | |||
auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); | |||
auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); | |||
auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); | |||
auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); | |||
auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); | |||
auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); | |||
auto const2 = builder.AddNode("const2", CONSTANTOP, 0, 1); | |||
auto switch1 = builder.AddNode("switch1", SWITCH, 2, 2); | |||
auto switch_false = builder.AddNode("switch_false", IDENTITY, 1, 1); | |||
auto switch_true = builder.AddNode("switch_true", IDENTITY, 1, 1); | |||
auto merge1 = builder.AddNode("merge1", MERGE, 2, 2); | |||
auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); | |||
builder.AddDataEdge(const1, 0, switch1, 0); | |||
builder.AddDataEdge(const2, 0, switch1, 1); | |||
builder.AddDataEdge(switch1, 0, switch_false, 0); | |||
builder.AddDataEdge(switch1, 1, switch_true, 0); | |||
builder.AddDataEdge(merge1, 0, net_output, 0); | |||
std::string batch_label_128 = "batch_128"; | |||
std::string batch_label_256 = "batch_256"; | |||
const int64_t buffer_pool_id = 0; | |||
const int64_t buffer_pool_size = 5600; | |||
{ | |||
auto prefetch1 = builder.AddNode("batch_label_128/prefetch1", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); | |||
auto prefetch2 = builder.AddNode("batch_label_128/prefetch2", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); | |||
auto prefetch3 = builder.AddNode("batch_label_128/prefetch3", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); | |||
auto prefetch4 = builder.AddNode("batch_label_128/prefetch4", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128); | |||
auto prefetch5 = builder.AddNode("batch_label_128/prefetch5", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128); | |||
auto add1 = builder.AddNode("batch_label_128/add1", ADD, 2, 1); | |||
SetBatchLabel(add1, batch_label_128); | |||
auto add2 = builder.AddNode("batch_label_128/add2", ADD, 2, 1); | |||
SetBatchLabel(add2, batch_label_128); | |||
auto add3 = builder.AddNode("batch_label_128/add3", ADD, 2, 1); | |||
SetBatchLabel(add3, batch_label_128); | |||
auto add4 = builder.AddNode("batch_label_128/add4", ADD, 2, 1); | |||
SetBatchLabel(add4, batch_label_128); | |||
auto add5 = builder.AddNode("batch_label_128/add5", ADD, 2, 1); | |||
SetBatchLabel(add5, batch_label_128); | |||
auto const1 = builder.AddNode("batch_label_128/const1", CONSTANTOP, 0, 1); | |||
SetBatchLabel(const1, batch_label_128); | |||
builder.AddDataEdge(w1, 0, prefetch1, 0); | |||
builder.AddDataEdge(w2, 0, prefetch2, 0); | |||
builder.AddDataEdge(w3, 0, prefetch3, 0); | |||
builder.AddDataEdge(w4, 0, prefetch4, 0); | |||
builder.AddDataEdge(w5, 0, prefetch5, 0); | |||
builder.AddDataEdge(const1, 0, add1, 0); | |||
builder.AddDataEdge(prefetch1, 0, add1, 1); | |||
builder.AddDataEdge(add1, 0, add2, 0); | |||
builder.AddDataEdge(prefetch2, 0, add2, 1); | |||
builder.AddDataEdge(add2, 0, add3, 0); | |||
builder.AddDataEdge(prefetch3, 0, add3, 1); | |||
builder.AddDataEdge(add3, 0, add4, 0); | |||
builder.AddDataEdge(prefetch4, 0, add4, 1); | |||
builder.AddDataEdge(add4, 0, add5, 0); | |||
builder.AddDataEdge(prefetch5, 0, add5, 1); | |||
builder.AddDataEdge(add5, 0, merge1, 0); | |||
builder.AddControlEdge(switch_false, const1); | |||
} | |||
{ | |||
auto prefetch1 = builder.AddNode("batch_label_256/prefetch1", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); | |||
auto prefetch2 = builder.AddNode("batch_label_256/prefetch2", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); | |||
auto prefetch3 = builder.AddNode("batch_label_256/prefetch3", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); | |||
auto prefetch4 = builder.AddNode("batch_label_256/prefetch4", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256); | |||
auto prefetch5 = builder.AddNode("batch_label_256/prefetch5", HCOMALLGATHER, 1, 1); | |||
SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256); | |||
auto add1 = builder.AddNode("batch_label_256/add1", ADD, 2, 1); | |||
SetBatchLabel(add1, batch_label_256); | |||
auto add2 = builder.AddNode("batch_label_256/add2", ADD, 2, 1); | |||
SetBatchLabel(add2, batch_label_256); | |||
auto add3 = builder.AddNode("batch_label_256/add3", ADD, 2, 1); | |||
SetBatchLabel(add3, batch_label_256); | |||
auto add4 = builder.AddNode("batch_label_256/add4", ADD, 2, 1); | |||
SetBatchLabel(add4, batch_label_256); | |||
auto add5 = builder.AddNode("batch_label_256/add5", ADD, 2, 1); | |||
SetBatchLabel(add5, batch_label_256); | |||
auto const1 = builder.AddNode("batch_label_256/const1", CONSTANTOP, 0, 1); | |||
SetBatchLabel(const1, batch_label_128); | |||
builder.AddDataEdge(w1, 0, prefetch1, 0); | |||
builder.AddDataEdge(w2, 0, prefetch2, 0); | |||
builder.AddDataEdge(w3, 0, prefetch3, 0); | |||
builder.AddDataEdge(w4, 0, prefetch4, 0); | |||
builder.AddDataEdge(w5, 0, prefetch5, 0); | |||
builder.AddDataEdge(const1, 0, add1, 0); | |||
builder.AddDataEdge(prefetch1, 0, add1, 1); | |||
builder.AddDataEdge(add1, 0, add2, 0); | |||
builder.AddDataEdge(prefetch2, 0, add2, 1); | |||
builder.AddDataEdge(add2, 0, add3, 0); | |||
builder.AddDataEdge(prefetch3, 0, add3, 1); | |||
builder.AddDataEdge(add3, 0, add4, 0); | |||
builder.AddDataEdge(prefetch4, 0, add4, 1); | |||
builder.AddDataEdge(add4, 0, add5, 0); | |||
builder.AddDataEdge(prefetch5, 0, add5, 1); | |||
builder.AddDataEdge(add5, 0, merge1, 1); | |||
builder.AddControlEdge(switch_true, const1); | |||
} | |||
auto compute_graph = builder.GetGraph(); | |||
return compute_graph; | |||
} | |||
///
/// GraphWithMultiOutputPrefetch: the single output of a prefetch node is consumed by more than one node
/// (prefetch1..prefetch4 each feed two add nodes; prefetch5 feeds only add5)
///
///                  w1         w2         w3         w4         w5
///                   \          \          \          \          \
///                prefetch1  prefetch2  prefetch3  prefetch4  prefetch5
///                  /   \      /   \      /   \      /   \      /
///                 /     \    /     \    /     \    /     \    /
///  const1 ----- add1     add2       add3       add4       add5
///                |          \         |         /          |
///                |            \       |       /            |
///                |              \     |     /              |
///                |                \   |   /                |
///                 -------------- net_output ---------------
///
///  Memory distribution:
///
///      |___w1__|__w2__|__w3__|__|
///
///      |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiOutputPrefetch() {
  // Five parallel branches: w<i> -> prefetch<i> -> add<i> -> net_output.
  const int branch_num = 5;
  const char *const weight_names[branch_num] = {"w1", "w2", "w3", "w4", "w5"};
  const char *const prefetch_names[branch_num] = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"};
  const char *const add_names[branch_num] = {"add1", "add2", "add3", "add4", "add5"};
  // Output memory size requested for each prefetch node, in branch order.
  const int64_t prefetch_out_sizes[branch_num] = {500, 500, 500, 1024, 1024};
  // Every prefetch node is attached to the same buffer pool.
  const int64_t pool_id = 0;
  const int64_t pool_size = 5600;

  InnerGraphBuilder graph_builder(graph_name_);

  // Nodes are created in the order: weights, prefetches, const1, adds, net_output.
  NodePtr weights[branch_num];
  for (int i = 0; i < branch_num; ++i) {
    weights[i] = graph_builder.AddNode(weight_names[i], VARIABLE, 0, 1);
  }

  NodePtr prefetches[branch_num];
  for (int i = 0; i < branch_num; ++i) {
    prefetches[i] = graph_builder.AddNode(prefetch_names[i], HCOMALLGATHER, 1, 1);
    SetPrefetchNodeInfo(prefetches[i], pool_id, pool_size, {prefetch_out_sizes[i]});
  }

  NodePtr const1 = graph_builder.AddNode("const1", CONSTANTOP, 0, 1);

  NodePtr adds[branch_num];
  for (int i = 0; i < branch_num; ++i) {
    adds[i] = graph_builder.AddNode(add_names[i], ADD, 2, 1);
  }

  NodePtr net_output = graph_builder.AddNode("net_output", NETOUTPUT, branch_num, 0);

  // Each weight feeds its own prefetch node.
  for (int i = 0; i < branch_num; ++i) {
    graph_builder.AddDataEdge(weights[i], 0, prefetches[i], 0);
  }

  // add1 takes const1 on input 0; every later add takes the previous prefetch there.
  graph_builder.AddDataEdge(const1, 0, adds[0], 0);
  for (int i = 0; i < branch_num; ++i) {
    // Output 0 of prefetch<i> goes to input 1 of add<i> ...
    graph_builder.AddDataEdge(prefetches[i], 0, adds[i], 1);
    // ... and, except for the last branch, also to input 0 of add<i+1>.
    if (i + 1 < branch_num) {
      graph_builder.AddDataEdge(prefetches[i], 0, adds[i + 1], 0);
    }
  }

  // add<i> drives net_output input i.
  for (int i = 0; i < branch_num; ++i) {
    graph_builder.AddDataEdge(adds[i], 0, net_output, i);
  }

  return graph_builder.GetGraph();
}
///
/// GraphWithMultiInputOutputPrefetch: Prefetch has more than one input and more than one output
/// (prefetch1..prefetch4 have two inputs and two outputs; prefetch5 has one of each)
///
///              w1         w2         w3         w4         w5
///                \       /  \       /  \       /  \       /  \
///                prefetch1  prefetch2  prefetch3  prefetch4  prefetch5
///                  /   \      /   \      /   \      /   \      /
///                 /     \    /     \    /     \    /     \    /
///  const1 ----- add1     add2       add3       add4       add5
///                |          \         |         /          |
///                |            \       |       /            |
///                |              \     |     /              |
///                |                \   |   /                |
///                 -------------- net_output ---------------
///
///  Memory distribution:
///
///      |___w1__|__w2__|__w3__|__|
///
///      |_____w4_____|_____w5____|
///
ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiInputOutputPrefetch() {
  // Five branches; neighbouring branches share weights and prefetch outputs.
  const int branch_num = 5;
  const char *const weight_names[branch_num] = {"w1", "w2", "w3", "w4", "w5"};
  const char *const prefetch_names[branch_num] = {"prefetch1", "prefetch2", "prefetch3", "prefetch4", "prefetch5"};
  const char *const add_names[branch_num] = {"add1", "add2", "add3", "add4", "add5"};
  // Number of input anchors (and, equally, output anchors) on each prefetch node:
  // the first four take two inputs and produce two outputs, the last one is 1-in/1-out.
  const int prefetch_anchor_cnt[branch_num] = {2, 2, 2, 2, 1};
  // Output memory sizes per prefetch node, one entry per output anchor.
  const std::vector<int64_t> prefetch_out_sizes[branch_num] = {
      {500, 500}, {500, 500}, {500, 1024}, {1024, 1024}, {1024}};
  // Every prefetch node is attached to the same buffer pool.
  const int64_t pool_id = 0;
  const int64_t pool_size = 5600;

  InnerGraphBuilder graph_builder(graph_name_);

  // Nodes are created in the order: weights, prefetches, const1, adds, net_output.
  NodePtr weights[branch_num];
  for (int i = 0; i < branch_num; ++i) {
    weights[i] = graph_builder.AddNode(weight_names[i], VARIABLE, 0, 1);
  }

  NodePtr prefetches[branch_num];
  for (int i = 0; i < branch_num; ++i) {
    prefetches[i] =
        graph_builder.AddNode(prefetch_names[i], HCOMALLGATHER, prefetch_anchor_cnt[i], prefetch_anchor_cnt[i]);
    SetPrefetchNodeInfo(prefetches[i], pool_id, pool_size, prefetch_out_sizes[i]);
  }

  NodePtr const1 = graph_builder.AddNode("const1", CONSTANTOP, 0, 1);

  NodePtr adds[branch_num];
  for (int i = 0; i < branch_num; ++i) {
    adds[i] = graph_builder.AddNode(add_names[i], ADD, 2, 1);
  }

  NodePtr net_output = graph_builder.AddNode("net_output", NETOUTPUT, branch_num, 0);

  for (int i = 0; i < branch_num; ++i) {
    // w<i> is input 0 of prefetch<i> ...
    graph_builder.AddDataEdge(weights[i], 0, prefetches[i], 0);
    // ... and w<i+1> is its input 1 (the last prefetch has a single input).
    if (i + 1 < branch_num) {
      graph_builder.AddDataEdge(weights[i + 1], 0, prefetches[i], 1);
    }
  }

  // add1 takes const1 on input 0; every later add takes output 1 of the previous prefetch there.
  graph_builder.AddDataEdge(const1, 0, adds[0], 0);
  for (int i = 0; i < branch_num; ++i) {
    // Output 0 of prefetch<i> goes to input 1 of add<i> ...
    graph_builder.AddDataEdge(prefetches[i], 0, adds[i], 1);
    // ... and output 1 (absent on the last prefetch) goes to input 0 of add<i+1>.
    if (i + 1 < branch_num) {
      graph_builder.AddDataEdge(prefetches[i], 1, adds[i + 1], 0);
    }
  }

  // add<i> drives net_output input i.
  for (int i = 0; i < branch_num; ++i) {
    graph_builder.AddDataEdge(adds[i], 0, net_output, i);
  }

  return graph_builder.GetGraph();
}
} // namespace ut | |||
} // namespace ge |
@@ -0,0 +1,279 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_ | |||
#define GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_ | |||
#include <string> | |||
#include <vector> | |||
#include "graph/compute_graph.h" | |||
#include "graph/graph.h" | |||
#include "graph/node.h" | |||
namespace ge { | |||
namespace ut { | |||
class BufferPoolGraphBuilder { | |||
public: | |||
explicit BufferPoolGraphBuilder(const std::string &name = "BufferPoolGraph"); | |||
~BufferPoolGraphBuilder() {} | |||
class InnerGraphBuilder { | |||
public: | |||
explicit InnerGraphBuilder(const std::string &name); | |||
~InnerGraphBuilder() {} | |||
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||
std::vector<int64_t> shape = {1, 1, 224, 224}); | |||
void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); | |||
void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); | |||
ComputeGraphPtr GetGraph() { | |||
graph_->TopologicalSorting(); | |||
return graph_; | |||
} | |||
private: | |||
ComputeGraphPtr graph_; | |||
}; | |||
/// | |||
/// Normal graph | |||
/// | |||
/// w1 w2 w3 w4 w5 | |||
/// \ \ \ \ \ | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | |||
/// \ \ \ \ \ | |||
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | |||
/// | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w2__|__w3__|__| | |||
/// | |||
/// |_____w4_____|_____w5____| | |||
/// | |||
ComputeGraphPtr BuildNormalGraph(); | |||
/// | |||
/// Normal graph with multi buffer pool | |||
/// | |||
/// w1 w2 w3 w4 w5 | |||
/// \ \ \ \ \ | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | |||
/// (pool0) (pool1) (pool0) (pool0) (pool1) | |||
/// \ \ \ \ \ | |||
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | |||
/// | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w3__|_________| | |||
/// |_____w4_____|___________| | |||
/// | |||
/// |___w2__|_____w5___|_____| | |||
/// | |||
ComputeGraphPtr BuildNormalGraphWithMultiBufferPool(); | |||
/// | |||
/// SerialGraph: Buffer pool size only can contain one prefetch node | |||
/// | |||
/// w1 w2 w3 w4 w5 | |||
/// \ \ \ \ \ | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | |||
/// \ \ \ \ \ | |||
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | |||
/// | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |____w1_____|__| | |||
/// | |||
/// |____w2_____|__| | |||
/// | |||
/// |____w3_____|__| | |||
/// | |||
/// |______w4______| | |||
/// | |||
/// |______w5______| | |||
/// | |||
ComputeGraphPtr BuildSerialGraph(); | |||
/// | |||
/// GraphWithMultiPrefetch: Calc node with more prefetch node | |||
/// | |||
/// w1 w2 w3 w4 w5 | |||
/// \ \ \ \ \ | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1 | |||
/// \ / \ / \ / | |||
/// \ / \ / \ / | |||
/// \ / \ / \ / | |||
/// add1 ------ c ------- add2 ----- c ----- add3 | |||
/// | | | | |||
/// | | | | |||
/// --------------- net_output ------------ | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w2__|__w3__|__| | |||
/// | |||
/// |_____w4_____|_____w5____| | |||
/// | |||
ComputeGraphPtr BuildGraphWithMultiPrefetch(); | |||
/// | |||
/// GraphWithSubgraph: Calc node in different subgraph | |||
/// | |||
/// | |||
/// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output | |||
/// | |||
/// | |||
/// Subgraph1: Subgraph2: | |||
/// | |||
/// w1 w2 w3 w4 w5 | |||
/// \ \ \ \ \ | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | |||
/// \ \ \ \ \ | |||
/// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out | |||
/// | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w2__|__w3__|__| | |||
/// | |||
/// |_____w4_____|_____w5____| | |||
/// | |||
ComputeGraphPtr BuildGraphWithSubgraph(); | |||
/// | |||
/// SubgraphWithInnerDependency: Calc node in different subgraph with inner dependency | |||
/// | |||
/// | |||
/// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output | |||
/// | |||
/// | |||
/// Subgraph1: Subgraph2: | |||
/// | |||
/// w1 w2 w3 w4 w5 | |||
/// \ \ \ \ \ | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | |||
/// \ \ \ \ \ | |||
/// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out | |||
/// | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w2__|__w3__|__| | |||
/// | |||
/// |_____w4_____|_____w5____| | |||
/// | |||
ComputeGraphPtr BuildSubgraphWithInnerDependency(); | |||
/// | |||
/// BuildGraphWithMultiBatch: Different batch label | |||
/// | |||
/// | |||
/// batch_label_128 | |||
/// | |||
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- | |||
/// / / / / / / \ | |||
/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \ | |||
/// const1 switch_false / / / / / \ | |||
/// \ / / / / / / \ | |||
/// switch1 w1 w2 w3 w4 w5 merge1 -- net_output | |||
/// / \ \ \ \ \ \ / | |||
/// const2 switch_true \ \ \ \ \ / | |||
/// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 / | |||
/// \ \ \ \ \ \ / | |||
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- | |||
/// | |||
/// batch_label_256 | |||
/// | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w2__|__w3__|__| | |||
/// | |||
/// |_____w4_____|_____w5____| | |||
/// | |||
ComputeGraphPtr BuildGraphWithMultiBatch(); | |||
/// | |||
/// GraphWithMultiOutputPrefetch: Prefetch has more than one output | |||
/// | |||
/// w1 w2 w3 w4 w5 | |||
/// \ \ \ \ \ | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | |||
/// / \ / \ / \ / \ / | |||
/// / \ / \ / \ / \ / | |||
/// const1 ----- add1 add2 add3 add4 add5 | |||
/// | \ | / | | |||
/// | \ | / | | |||
/// | \ | / | | |||
/// | \ | / | | |||
/// -------------- net_output --------------- | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w2__|__w3__|__| | |||
/// | |||
/// |_____w4_____|_____w5____| | |||
/// | |||
ComputeGraphPtr BuildGraphWithMultiOutputPrefetch(); | |||
/// | |||
/// GraphWithMultiOutputPrefetch: Prefetch has more than one output | |||
/// | |||
/// w1 w2 w3 w4 w5 | |||
/// \ / \ / \ / \ / \ | |||
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | |||
/// / \ / \ / \ / \ / | |||
/// / \ / \ / \ / \ / | |||
/// const1 ----- add1 add2 add3 add4 add5 | |||
/// | \ | / | | |||
/// | \ | / | | |||
/// | \ | / | | |||
/// | \ | / | | |||
/// -------------- net_output --------------- | |||
/// | |||
/// Memory distribution: | |||
/// | |||
/// |___w1__|__w2__|__w3__|__| | |||
/// | |||
/// |_____w4_____|_____w5____| | |||
/// | |||
ComputeGraphPtr BuildGraphWithMultiInputOutputPrefetch(); | |||
void SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size, const std::string &batch_label = ""); | |||
void SetBatchLabel(NodePtr &node, const std::string &batch_label = ""); | |||
void SetOutputMemSize(NodePtr &node, const std::vector<int64_t> &mem_size = {1024}); | |||
void SetWorkSpaceMemSize(NodePtr &node, const std::vector<int64_t> &ws_bytes = {1024}); | |||
void SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size, | |||
const std::vector<int64_t> &mem_size = {1024}, | |||
const std::vector<int64_t> &ws_bytes = {1024}, | |||
const std::string &batch_label = ""); | |||
private: | |||
std::string graph_name_; | |||
}; | |||
} // namespace ut | |||
} // namespace ge | |||
#endif // GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_ |