@@ -375,6 +375,7 @@ set(TRAIN_SRC_LIST
     "hybrid/node_executor/host_cpu/kernel/variable_kernel.cc"
     "hybrid/node_executor/host_cpu/kernel/assign_kernel.cc"
     "hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc"
+    "hybrid/node_executor/host_cpu/kernel/data_kernel.cc"
     "hybrid/node_executor/controlop/control_op_executor.cc"
     "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
     "hybrid/node_executor/hccl/hccl_node_executor.cc"

@@ -388,6 +388,7 @@ REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite");
+REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
 REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign");
 REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp");

@@ -104,6 +104,7 @@ set(SRC_LIST
     "../hybrid/node_executor/host_cpu/kernel/variable_kernel.cc"
     "../hybrid/node_executor/host_cpu/kernel/assign_kernel.cc"
     "../hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc"
+    "../hybrid/node_executor/host_cpu/kernel/data_kernel.cc"
     "../hybrid/node_executor/controlop/control_op_executor.cc"
     "../hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
     "../hybrid/node_executor/rts/rts_node_executor.cc"

@@ -95,6 +95,7 @@ local_ge_executor_src_files := \
     ../hybrid/node_executor/host_cpu/kernel/variable_kernel.cc \
     ../hybrid/node_executor/host_cpu/kernel/assign_kernel.cc \
     ../hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc \
+    ../hybrid/node_executor/host_cpu/kernel/data_kernel.cc \
     ../hybrid/node_executor/controlop/control_op_executor.cc \
     ../hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \
     ../hybrid/node_executor/rts/rts_node_executor.cc \

@@ -300,6 +300,7 @@ LIBGE_LOCAL_SRC_FILES := \
     hybrid/node_executor/host_cpu/kernel/variable_kernel.cc \
     hybrid/node_executor/host_cpu/kernel/assign_kernel.cc \
     hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc \
+    hybrid/node_executor/host_cpu/kernel/data_kernel.cc \
     hybrid/node_executor/controlop/control_op_executor.cc \
     hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \
     hybrid/node_executor/hccl/hccl_node_executor.cc \
@@ -60,9 +60,14 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr
                     return FAILED);
     ge::ConstGeTensorDescPtr tensor_desc = n->GetOpDesc()->GetOutputDescPtr(0);
     GE_CHECK_NOTNULL(tensor_desc);
+    rtMemType_t memory_type = RT_MEMORY_HBM;
+    uint32_t mem_type = 0;
+    if (AttrUtils::GetInt(n->GetOpDesc(), ATTR_OUTPUT_MEMORY_TYPE, mem_type) && (mem_type == 1)) {
+      memory_type = RT_MEMORY_RDMA_HBM;
+    }
     if (!VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(node_name, *tensor_desc)) {
       GE_CHK_STATUS_RET(
-          VarManager::Instance(compute_graph->GetSessionID())->AssignVarMem(node_name, *tensor_desc, RT_MEMORY_HBM));
+          VarManager::Instance(compute_graph->GetSessionID())->AssignVarMem(node_name, *tensor_desc, memory_type));
       GE_IF_BOOL_EXEC(n->GetType() == VARIABLE,
                       GE_CHK_STATUS_RET(AssignData2Fp32Var(n, compute_graph->GetSessionID())));
       GE_CHK_STATUS_RET(VarManager::Instance(compute_graph->GetSessionID())

@@ -70,7 +75,6 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr
     }
     uint8_t *dev_ptr = nullptr;
-    rtMemType_t memory_type = RT_MEMORY_HBM;
     GE_CHK_STATUS_RET(VarManager::Instance(compute_graph->GetSessionID())
                           ->GetVarAddr(node_name, *tensor_desc, &dev_ptr, memory_type));
     vector<int64_t> output_list = n->GetOpDesc()->GetOutputOffset();
@@ -15,18 +15,10 @@
  */

 #include "graph/load/new_model_manager/model_utils.h"
 #include <string>
 #include "common/debug/log.h"
 #include "common/op/ge_op_utils.h"
-#include "graph/debug/ge_attr_define.h"
-#include "graph/utils/attr_utils.h"
 #include "graph/utils/tensor_utils.h"
-#include "runtime/base.h"
-#include "runtime/kernel.h"
-#include "framework/common/debug/ge_log.h"
 #include "graph/manager/graph_var_manager.h"

 #define VALIDATE_MEM_RANGE(OP, SIZE, OFFSET) \

@@ -342,8 +334,8 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co
     int64_t input_offset = v_input_offset[non_const_index];
     non_const_index++;
     GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset),
-                    VALIDATE_MEM_RANGE(op_desc, model_param.var_size, input_offset - model_param.logic_var_base);
-                    uint8_t *variable_addr = model_param.var_base + input_offset - model_param.logic_var_base;
+                    uint8_t *variable_addr = nullptr;
+                    GE_CHK_STATUS_EXEC(GetVarAddr(model_param, op_desc, input_offset, variable_addr), return {});
                     v_input_data_addr.push_back(variable_addr);
                     GELOGI("[IMAS]GetInputDataAddrs graph_%u type[V] name[%s] input[%lu] memaddr[%p]",
                            model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr);

@@ -380,6 +372,27 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co
   return v_input_data_addr;
 }

+///
+/// @ingroup ge
+/// @brief Get variable address.
+/// @return Status
+///
+Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset,
+                              uint8_t *&var_addr) {
+  if (ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset) == RT_MEMORY_RDMA_HBM) {
+    if (offset < 0) {
+      GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset));
+      return PARAM_INVALID;
+    }
+    var_addr = reinterpret_cast<uint8_t *>(offset);
+    GE_CHECK_NOTNULL(var_addr);
+  } else {
+    VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base);
+    var_addr = model_param.var_base + offset - model_param.logic_var_base;
+  }
+  return SUCCESS;
+}
+
 ///
 /// @ingroup ge
 /// @brief Get output data address.

@@ -405,8 +418,8 @@ vector<void *> ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C
   }
   for (size_t i = 0; i < outputs_size; ++i) {
     GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(v_output_offset[i]),
-                    VALIDATE_MEM_RANGE(op_desc, model_param.var_size, v_output_offset[i] - model_param.logic_var_base);
-                    uint8_t *variable_addr = model_param.var_base + v_output_offset[i] - model_param.logic_var_base;
+                    uint8_t *variable_addr = nullptr;
+                    GE_CHK_STATUS_EXEC(GetVarAddr(model_param, op_desc, v_output_offset[i], variable_addr), return {});
                     v_output_data_addr.push_back(variable_addr);
                     GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[V] name[%s] output[%zu] memaddr[%p]",
                            model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr);

@@ -107,6 +107,15 @@ class ModelUtils {
   /// @return Status
   ///
   static Status GetRtAddress(const RuntimeParam &model_param, uintptr_t logic_addr, uint8_t *&mem_addr);
+
+ private:
+  ///
+  /// @ingroup ge
+  /// @brief Get variable address.
+  /// @return Status
+  ///
+  static Status GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset,
+                           uint8_t *&var_addr);
 };
 }  // namespace ge
@@ -16,17 +16,10 @@

 #include "graph/manager/graph_var_manager.h"

-#include <utility>
-
-#include "common/l2_cache_optimize.h"
-#include "common/types.h"
-#include "framework/common/debug/ge_log.h"
-#include "framework/common/debug/log.h"
-#include "ge/ge_api_types.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/manager/graph_mem_allocator.h"
+#include "graph/manager/rdma_pool_allocator.h"
 #include "graph/manager/trans_var_data_utils.h"
-#include "graph/utils/attr_utils.h"
 #include "graph/utils/type_utils.h"

 using std::map;

@@ -37,7 +30,7 @@ namespace ge {
 VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}

 VarResource::~VarResource() {
-  var_offset_set_.clear();
+  var_offset_map_.clear();
   var_addr_mgr_map_.clear();
   cur_var_tensor_desc_map_.clear();
   var_broad_cast_info_.clear();

@@ -91,8 +84,10 @@ ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTen
   std::string var_key = VarKey(var_name, tensor_desc);
   GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
   if (var_addr_mgr_map_.count(var_key) == 0) {
-    uint64_t logic_address = VarManager::Instance(session_id_)->GetVarMemLogicBase() +
-                             static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
+    uint64_t logic_address = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
+    if (memory_type != RT_MEMORY_RDMA_HBM) {
+      logic_address += VarManager::Instance(session_id_)->GetVarMemLogicBase();
+    }
     GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
            TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
            TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());

@@ -102,7 +97,7 @@ ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTen
     var_addr_mgr.tensor_desc = tensor_desc;
     var_addr_mgr.memory_type = memory_type;
     var_addr_mgr_map_[var_key] = var_addr_mgr;
-    var_offset_set_.insert(logic_address);
+    var_offset_map_[logic_address] = memory_type;

     return SUCCESS;
   }

@@ -211,7 +206,14 @@ ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_na
   return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
 }

-bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_set_.count(offset) > 0; }
+bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }
+
+rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
+  if (var_offset_map_.count(offset) > 0) {
+    return var_offset_map_[offset];
+  }
+  return RT_MEMORY_HBM;
+}

 VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
   auto iter = var_to_trans_road_.find(var_name);

@@ -252,7 +254,19 @@ Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t gr
 MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}

-Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) {
+MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
+  switch (mem_type) {
+    case RT_MEMORY_HBM:
+      return new (std::nothrow) HbmMemResource();
+    case RT_MEMORY_RDMA_HBM:
+      return new (std::nothrow) RdmaMemResource();
+    default:
+      return nullptr;
+  }
+}
+
+Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id,
+                                    size_t &mem_offset) {
   size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
   uint64_t real_size = size;
   total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();

@@ -282,6 +296,19 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin
   return SUCCESS;
 }
+Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) {
+  uint8_t *buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(size);
+  if (buffer == nullptr) {
+    GELOGE(MEMALLOC_FAILED, "Failed to malloc rdma memory for node %s, size = %llu", var_name.c_str(), size);
+    return MEMALLOC_FAILED;
+  }
+  address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer));
+  var_mem_size_ += size;
+  GELOGI("[IMAS]AssignVarMem Set session_%llu name[%s] output[%d] addr to [%p] size[%llu].",
+         session_id, var_name.c_str(), 0, buffer, size);
+  return SUCCESS;
+}
+
 uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }

 void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
@@ -428,7 +455,7 @@ Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
   MemResource *mem_resource = nullptr;
   auto iter = mem_resource_map_.find(memory_type);
   if (iter == mem_resource_map_.end()) {
-    mem_resource = new (std::nothrow) MemResource();
+    mem_resource = MemResource::BuildMemResourceFromType(memory_type);
     if (mem_resource == nullptr) {
       GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
       return ge::INTERNAL_ERROR;

@@ -465,7 +492,7 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen
   MemResource *mem_resource = nullptr;
   auto it = mem_resource_map_.find(memory_type);
   if (it == mem_resource_map_.end()) {
-    mem_resource = new (std::nothrow) MemResource();
+    mem_resource = MemResource::BuildMemResourceFromType(memory_type);
     if (mem_resource == nullptr) {
       GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
       return ge::INTERNAL_ERROR;

@@ -629,6 +656,15 @@ bool VarManager::IsVarAddr(const int64_t &offset) {
   return var_resource_->IsVarAddr(offset);
 }

+rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+  if (var_resource_ == nullptr) {
+    GELOGW("VarManager has not been init.");
+    return RT_MEMORY_HBM;
+  }
+  return var_resource_->GetVarMemType(offset);
+}
+
 ge::Status VarManager::MallocVarMemory(size_t memory_size) {
   std::lock_guard<std::recursive_mutex> lock(mutex_);
   uint8_t *var_mem_base = nullptr;
@@ -654,12 +690,18 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) {

 uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
   std::lock_guard<std::recursive_mutex> lock(mutex_);
+  if (memory_type == RT_MEMORY_RDMA_HBM) {
+    return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr();
+  }
   string memory_key = std::to_string(session_id_);
   return MemManager::Instance(memory_type)->GetMemoryAddr(memory_key);
 }

 uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
   std::lock_guard<std::recursive_mutex> lock(mutex_);
+  if (memory_type == RT_MEMORY_RDMA_HBM) {
+    return logic_addr;
+  }
   string mem_key = std::to_string(session_id_);
   uint8_t *mem_base = MemManager::Instance(memory_type)->GetMemoryAddr(mem_key);
   if (mem_base == nullptr) {

@@ -158,13 +158,15 @@ class VarResource {
   bool IsVarAddr(const int64_t &offset);

+  rtMemType_t GetVarMemType(const int64_t &offset);
+
   std::unordered_map<std::string, ge::GeTensorDesc> GetAllVarDesc() const { return cur_var_tensor_desc_map_; }

  private:
   std::string VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc);

   uint64_t session_id_;
-  std::unordered_set<uint64_t> var_offset_set_;
+  std::unordered_map<uint64_t, rtMemType_t> var_offset_map_;
   std::unordered_map<std::string, VarAddrMgr> var_addr_mgr_map_;
   std::unordered_map<std::string, ge::GeTensorDesc> cur_var_tensor_desc_map_;
   std::unordered_map<std::string, std::vector<TransNodeInfo>> var_to_trans_road_;

@@ -176,19 +178,36 @@ class VarResource {
 class MemResource {
  public:
   MemResource();
-  ~MemResource() = default;
+  virtual ~MemResource() = default;
+  static MemResource *BuildMemResourceFromType(rtMemType_t mem_type);

-  Status AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset);
+  virtual Status AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) = 0;

   uint64_t GetVarMemSize() const;
   void UpdateVarMemSize(int64_t mem_size);

- private:
+ protected:
   uint64_t total_size_;
   uint64_t var_mem_size_;
 };

+class HbmMemResource : public MemResource {
+ public:
+  HbmMemResource() = default;
+  ~HbmMemResource() override = default;
+  Status AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) override;
+};
+
+class RdmaMemResource : public MemResource {
+ public:
+  RdmaMemResource() = default;
+  ~RdmaMemResource() override = default;
+  Status AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) override;
+};
+
 class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {
  public:
   static VarManager *Instance(uint64_t session_id);

@@ -275,6 +294,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {
   bool IsVarAddr(const int64_t &offset);

+  rtMemType_t GetVarMemType(const int64_t &offset);
+
   uint8_t *GetVarMemoryBase(rtMemType_t memory_type);

   uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type);

@@ -53,6 +53,10 @@ class RdmaPoolAllocator {
   Status GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size);

+  uint8_t *GetRdmaBaseAddr() { return rdma_base_addr_; }
+
+  size_t GetRdmaMemSize() { return rdma_mem_size_; }
+
  private:
   void MergeBlocks(Block *dst, Block *src);
@@ -213,6 +213,7 @@ std::string DynamicShapePartitioner::DebugString() const {
   size_t data = 0;
   size_t netoutput = 0;
   size_t is_inputnode = 0;
+  size_t stage = 0;
   std::stringstream ss;
   ss << "All unknown shape nodes:" << std::endl;
   for (const auto &node : unknown_shape_nodes_) {

@@ -229,10 +230,13 @@ std::string DynamicShapePartitioner::DebugString() const {
       netoutput++;
     } else if (cluster->IsInputNode()) {
       is_inputnode++;
+    } else if (cluster->IsIndependent()) {
+      stage++;
     }
   }
   ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known
-     << ", unknown:" << unknown << ", netoutput:" << netoutput << ", is_inputnode:" << is_inputnode << std::endl;
+     << ", unknown:" << unknown << ", netoutput:" << netoutput << ", is_inputnode:" << is_inputnode
+     << ", stage:" << stage << std::endl;
   for (const auto &cluster : unique_clusters_) {
     ss << " " << cluster->DebugString() << std::endl;
   }

@@ -272,12 +276,15 @@ Status DynamicShapePartitioner::InitClusters() {
   for (const auto &node : graph->GetDirectNode()) {
     Cluster::Type type = Cluster::DATA;
     bool is_input = ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) && node->GetInNodes().empty();
+    REQUIRE_NOT_NULL(node->GetOpDesc(), "op_desc is null");
     if (node->GetType() == DATA) {
       type = Cluster::DATA;
     } else if (is_input) {
       type = Cluster::INPUT_NODE;
     } else if (node->GetType() == NETOUTPUT) {
       type = Cluster::NETOUTPUT;
+    } else if ((node->GetType() == PARTITIONEDCALL) && (node->GetOpDesc()->HasAttr(ATTR_STAGE_LEVEL))) {
+      type = Cluster::STAGE;
     } else if (unknown_shape_nodes_.count(node) > 0) {
       type = Cluster::UNKNOWN_SHAPE;
     } else {

@@ -360,6 +367,9 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
 void DynamicShapePartitioner::MergeClustersUnknownShape() {
   // Merge unknown shape clusters
   for (const auto &cluster : ordered_cluster_) {
+    if (cluster->IsIndependent()) {
+      continue;
+    }
     for (const auto &in_cluster : cluster->Inputs()) {
       if (!in_cluster->IsUnknownShape()) {
         continue;

@@ -379,6 +389,9 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() {
 void DynamicShapePartitioner::MergeClustersKnownShape() {
   // Merge known shape clusters
   for (const auto &cluster : ordered_cluster_) {
+    if (cluster->IsIndependent()) {
+      continue;
+    }
     if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) {
       auto in_cluster = *(cluster->Inputs().begin());
       in_cluster->Merge(cluster);

@@ -606,6 +619,7 @@ void Cluster::UpdateRank(size_t rank) {
 bool Cluster::IsData() const { return type_ == DATA; };
 bool Cluster::IsKnownShape() const { return type_ == KNOWN_SHAPE; };
 bool Cluster::IsUnknownShape() const { return type_ == UNKNOWN_SHAPE; };
+bool Cluster::IsIndependent() const { return type_ == STAGE; };
 bool Cluster::IsNetOutput() const { return type_ == NETOUTPUT; };
 bool Cluster::IsInputNode() const { return type_ == INPUT_NODE; };
 bool Cluster::IsRefVariable() const {

@@ -641,6 +655,9 @@ void Cluster::RemoveOutput(ClusterPtr out) {
                            out->in_clusters_.end());
 };
 void Cluster::Merge(ClusterPtr other) {
+  if (other->IsIndependent()) {
+    return;
+  }
   nodes_.insert(nodes_.end(), other->nodes_.begin(), other->nodes_.end());
   other->in_clusters_.erase(std::remove(other->in_clusters_.begin(), other->in_clusters_.end(), shared_from_this()),
                             other->in_clusters_.end());

@@ -689,7 +706,9 @@ std::vector<ClusterPtr> Cluster::MergeAllPathFrom(ClusterPtr other) {
   std::unordered_set<ClusterPtr> forward_reached_clusters;
   std::unordered_set<ClusterPtr> backward_reached_clusters;
   std::vector<ClusterPtr> path_clusters;
+  if (other->IsIndependent()) {
+    return path_clusters;
+  }
   if (std::find(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()) ==
       other->out_clusters_.end()) {
     return path_clusters;

@@ -772,7 +791,7 @@ Status Cluster::BuildFrame() {
       }
     }
   }
-  if (IsData()) {
+  if (IsData() || IsIndependent()) {
     for (const auto &anchor : node->GetAllOutDataAnchors()) {
       AddFrameOutput(anchor);
     }

@@ -888,7 +907,7 @@ Status Cluster::CombinePartitionFrame() {
 }

 Status Cluster::BuildPartitionSubgraph() {
-  if (IsData() || IsNetOutput()) {
+  if (IsData() || IsNetOutput() || IsIndependent()) {
     return SUCCESS;
   }
   int64_t parent_node_index = 0;

@@ -32,7 +32,7 @@ class DynamicShapePartitioner {
   // DATA:DATA, UNKNOWN_SHAPE:unknowshape, KNOWN_SHAPE:knowshape, NETOUTPUT:NETOUTPUT.
   class Cluster : public std::enable_shared_from_this<Cluster> {
    public:
-    enum Type { DATA, INPUT_NODE, NETOUTPUT, KNOWN_SHAPE, UNKNOWN_SHAPE };
+    enum Type { DATA, INPUT_NODE, NETOUTPUT, STAGE, KNOWN_SHAPE, UNKNOWN_SHAPE };
     Cluster(size_t rank, Type type, NodePtr node, DynamicShapePartitioner *partitioner)
         : id_(rank), min_(rank), max_(rank), type_(type), partitioner_(partitioner) {
       nodes_.push_back(node);

@@ -45,6 +45,7 @@ class DynamicShapePartitioner {
     bool IsData() const;
     bool IsKnownShape() const;
     bool IsUnknownShape() const;
+    bool IsIndependent() const;
     bool IsNetOutput() const;
     std::vector<std::shared_ptr<Cluster>> Inputs() const;
     std::vector<std::shared_ptr<Cluster>> Outputs() const;
@@ -25,6 +25,10 @@
 #include "common/types.h"

 namespace ge {
+namespace {
+const std::set<std::string> kSrcNodeTypes = { DATA, AIPPDATA, ANN_DATA };
+}
+
 Status StagePartitioner::Partition() {
   GE_CHECK_NOTNULL(root_graph_);
   if (root_graph_->GetParentGraph() != nullptr) {

@@ -37,6 +41,10 @@ Status StagePartitioner::Partition() {
     if (!AttrUtils::GetInt(op_desc, ATTR_STAGE_LEVEL, level)) {
       continue;
     }
+    if ((kSrcNodeTypes.count(op_desc->GetType()) != 0) && node->GetInAllNodes().empty()) {
+      continue;
+    }
+    GELOGD("original node %s for stage %u", node->GetName().c_str(), level);
     stage_nodes_[level].insert(node);
   }
   if (stage_nodes_.empty()) {

@@ -54,6 +62,13 @@ Status StagePartitioner::Partition() {
     return FAILED;
   }

+  root_graph_->TopologicalSorting([](const NodePtr &a, const NodePtr &b) -> bool {
+    uint32_t a_level = UINT32_MAX;
+    (void)AttrUtils::GetInt(a->GetOpDesc(), ATTR_STAGE_LEVEL, a_level);
+    uint32_t b_level = UINT32_MAX;
+    (void)AttrUtils::GetInt(b->GetOpDesc(), ATTR_STAGE_LEVEL, b_level);
+    return a_level < b_level;
+  });
   if (root_graph_->TopologicalSorting() != GRAPH_SUCCESS) {
     GELOGE(FAILED, "Topological sort for graph %s after stage partition failed, "
            "maybe stage_level was not set correctly.", root_graph_->GetName().c_str());

@@ -76,20 +91,26 @@ Status StagePartitioner::SplitStageLevel() {
       auto node = nodes.top();
       nodes.pop();
       GE_CHECK_NOTNULL(node->GetOpDesc());
-      if (node->GetOpDesc()->HasAttr(ATTR_STAGE_LEVEL) && (cur_stage_nodes.count(node) == 0)) {
+      uint32_t tmp_level = cur_stage_level;
+      (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_STAGE_LEVEL, tmp_level);
+      if (tmp_level != cur_stage_level) {
        continue;
       }
       for (const auto &in_node : node->GetInAllNodes()) {
         if (visited_stage_nodes.count(in_node) != 0) {
           continue;
         }
-        if (!AttrUtils::SetInt(in_node->GetOpDesc(), ATTR_STAGE_LEVEL, cur_stage_level)) {
-          GELOGE(INTERNAL_ERROR, "Set attr ATTR_STAGE_LEVEL on node %s failed.", in_node->GetName().c_str());
-          return INTERNAL_ERROR;
-        }
-        GELOGD("Mark stage_level node %s, stage_level=%u", in_node->GetName().c_str(), cur_stage_level);
+        if ((kSrcNodeTypes.count(in_node->GetType()) != 0) && in_node->GetInAllNodes().empty()) {
+          GELOGD("skip data node %s for stage %u", in_node->GetName().c_str(), cur_stage_level);
+          continue;
+        }
         nodes.push(in_node);
       }
+      if (!AttrUtils::SetInt(node->GetOpDesc(), ATTR_STAGE_LEVEL, cur_stage_level)) {
+        GELOGE(INTERNAL_ERROR, "Set attr ATTR_STAGE_LEVEL on node %s failed.", node->GetName().c_str());
+        return INTERNAL_ERROR;
+      }
+      GELOGD("Mark stage_level node %s, stage_level=%u", node->GetName().c_str(), cur_stage_level);
       visited_stage_nodes.emplace(node);
     }
     for (const auto &node : visited_stage_nodes) {

@@ -219,6 +240,11 @@ NodePtr StagePartitioner::BuildSubgraphNode(const std::string &graph_name, const
   op_desc->AddSubgraphName("f");
   op_desc->SetSubgraphInstanceName(0, graph_name);

+  if (!AttrUtils::SetInt(op_desc, ATTR_STAGE_LEVEL, stage_info.stage_level)) {
+    GELOGE(INTERNAL_ERROR, "Set attr ATTR_STAGE_LEVEL on node %s failed", op_desc->GetName().c_str());
+    return nullptr;
+  }
+
   NodePtr subgraph_node = root_graph_->AddNode(op_desc);
   if (subgraph_node == nullptr) {
     GELOGE(FAILED, "Add node %s failed.", graph_name.c_str());
@@ -142,17 +142,18 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node
     GE_CHECK_NOTNULL(in_node);

     // Need insert memcpy
-    //   1. Const->NetOutput in subgraph
+    //   1. Const->NetOutput in subgraph & parent graph is known
     //   2. AtomicOp->NetOutput in subgraph
     //   3. OutputContinuesRequiredOp->NetOutput in subgraph
     //   4. Data->NetOutput in subgraph but parent_node is not while
     //   5. While->NetOutput in known subgraph
     std::string op_type;
-    bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) ||
+    bool insert_flag =
+        (NodeUtils::GetConstOpType(in_node, op_type) && !graph->GetParentGraph()->GetGraphUnknownFlag()) ||
         IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) ||
         ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) ||
         (!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) &&
-                       (kWhileOpTypes.count(in_node->GetType()) != 0));
+         (kWhileOpTypes.count(in_node->GetType()) != 0));
     if (insert_flag) {
       GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str());
       std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy";
@@ -32,5 +32,8 @@ REGISTER_OP_CREATOR(Assign, HostOp);
 REGISTER_OP_CREATOR(RandomUniform, HostOp);
 REGISTER_OP_CREATOR(Add, HostOp);
 REGISTER_OP_CREATOR(Mul, HostOp);
+REGISTER_OP_CREATOR(ConcatV2, HostOp);
+REGISTER_OP_CREATOR(Data, HostOp);
+REGISTER_OP_CREATOR(Fill, HostOp);
 }  // namespace host_cpu
 }  // namespace ge
@@ -59,6 +59,7 @@ Status HybridModelAsyncExecutor::Start(const std::shared_ptr<ModelListener> &lis
   run_flag_ = true;
   listener_ = listener;
   future_ = std::async(std::launch::async, [&]() -> Status {
+    GetThreadLocalContext() = *executor_->GetContext()->ge_context;
     GetContext().SetSessionId(executor_->GetContext()->session_id);
     return RunInternal();
   });

@@ -229,7 +230,11 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData &current_data, Hy
     }
     GE_CHECK_GE(tensor_size, 0);
-    auto tensor_buffer = TensorBuffer::Create(allocator, tensor_size);
+    AllocationAttr attr;
+    if (GetContext().GetHostExecFlag()) {
+      attr.SetMemType(HOST_DDR);
+    }
+    auto tensor_buffer = TensorBuffer::Create(allocator, tensor_size, &attr);
     GE_CHECK_NOTNULL(tensor_buffer);
     args.inputs.emplace_back(std::shared_ptr<TensorBuffer>(tensor_buffer.release()));

@@ -772,7 +772,12 @@ Status HybridModelBuilder::VarNodeToTensor(const NodePtr &var_node, std::unique_
          var_name.c_str(),
          hybrid_model_.GetSessionId());

-  uint8_t *dev_mem = var_manager_->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
+  rtMemType_t memory_type = RT_MEMORY_HBM;
+  uint32_t mem_type = 0;
+  if (AttrUtils::GetInt(var_node->GetOpDesc(), ATTR_OUTPUT_MEMORY_TYPE, mem_type) && (mem_type == 1)) {
+    memory_type = RT_MEMORY_RDMA_HBM;
+  }
+  uint8_t *dev_mem = var_manager_->GetVarMemoryAddr(var_logic, memory_type);
   if (dev_mem == nullptr) {
     GELOGE(INTERNAL_ERROR,
            "Failed to copy var %s from device, cant not get "
@@ -15,23 +15,25 @@
  */

 #include "hybrid/node_executor/hccl/hccl_node_executor.h"
+#include "common/ge/ge_util.h"
 #include "common/ge/plugin_manager.h"
 #include "common/math/math_util.h"
-#include "framework/common/debug/ge_log.h"
 #include "graph/attr_value.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/manager/util/hcom_util.h"
 #include "graph/runtime_inference_context.h"
-#include "hccl/hcom.h"
+#include "graph/utils/type_utils.h"
+#include "hybrid/executor/hybrid_execution_context.h"

+namespace ge {
 namespace {
-const size_t kVarTableDims = 2;
-const size_t kVarTableRowCnt = 3;
-const size_t kVarTableIdxAddr = 1;
-const size_t kVarTableIdxLen = 2;
+constexpr size_t kVarTableDims = 2;
+constexpr size_t kVarTableRowCnt = 3;
+constexpr size_t kVarTableIdxAddr = 1;
+constexpr size_t kVarTableIdxLen = 2;
+const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD };
+const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE };
+const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE };
 }  // namespace
-namespace ge {
 namespace hybrid {

 REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HCCL, HcclNodeExecutor);

@@ -142,11 +144,22 @@ Status RdmaNodeTask::Init(TaskContext &context) {
   GE_CHECK_NOTNULL(peer_node->GetOpDesc());

   remote_index_ = {peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx()};
-  if (node_item.node->GetType() == HCOMREMOTEREAD) {
+  if (kRdmaReadTypes.count(node_item.node->GetType()) > 0) {
     local_index_ = 0;
   } else {
     local_index_ = op_desc->GetInputIndexByName("local");
   }
+  int32_t offset_idx = node_item.op_desc->GetInputIndexByName("local_offset");
+  if ((offset_idx != -1) && (node_item.op_desc->GetInputDescPtr(offset_idx) != nullptr)) {
+    skip_flag_ = true;
+    GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx));
+    GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor());
+    GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode());
+    GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc());
+    offset_index_ = {
+        node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc()->GetId(),
+        node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetIdx() };
+  }
   return SUCCESS;
 }

@@ -158,8 +171,13 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
   GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
   auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
   if (data == nullptr) {
-    GELOGE(FAILED, "Tensor data is nullptr.");
-    return FAILED;
+    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
+      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
+      return SUCCESS;
+    } else {
+      GELOGE(FAILED, "Tensor data is nullptr.");
+      return FAILED;
+    }
   }
   auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
   if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
@@ -183,30 +201,63 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
       auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
       GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
     }
+  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
+    AllocationAttr attr;
+    attr.SetMemType(RDMA_HBM);
+    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
   }

   TensorValue *tv;
-  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
-    tv = context.MutableOutput(0);
+  if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) {
+    tv = context.MutableOutput(local_index_);
   } else {
     tv = context.MutableInput(local_index_);
   }
   GE_CHECK_NOTNULL(tv);
-  auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData()));
   auto row_num = dims.front();
   addr_infos.resize(row_num);
-  auto device_len = tv->GetSize() / row_num;
-  if (device_len <= 0 || device_len > data[kVarTableIdxLen]) {
-    GELOGE(FAILED, "Local embedding length is out of range.");
-    return FAILED;
-  }
-  for (auto idx = 0; idx < row_num; ++idx) {
-    FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
-    auto line_idx = idx * kVarTableRowCnt;
-    addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr,
-                       device_len};
-    local_addr += device_len;
+  if (skip_flag_) {
+    int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset");
+    GE_CHECK_NOTNULL(context.GetNodeItem().op_desc->GetInputDescPtr(offset_idx));
+    auto data_type = context.GetNodeItem().op_desc->GetInputDesc(offset_idx).GetDataType();
+    Tensor offset_tensor;
+    GE_CHK_STATUS_RET(ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor))
+    if (static_cast<int64_t>(offset_tensor.GetSize() / GetSizeByDataType(data_type)) != row_num) {
+      GELOGE(PARAM_INVALID, "num of offset and remote addr mismatch, offset size=%zu, remote_addr size=%lld, dtype=%s",
+             offset_tensor.GetSize(), row_num, TypeUtils::DataTypeToSerialString(data_type).c_str());
+      return PARAM_INVALID;
+    }
+    auto addr_offset = reinterpret_cast<uint64_t *>(offset_tensor.GetData());
+    GE_CHECK_NOTNULL(addr_offset);
+    auto base_addr = reinterpret_cast<float *>(tv->MutableData());
+    GE_CHECK_NOTNULL(base_addr);
+    for (auto idx = 0; idx < row_num; idx++) {
+      FMK_INT64_MULCHECK(idx, kVarTableRowCnt)
+      auto line_idx = idx * kVarTableRowCnt;
+      addr_infos[idx] = { static_cast<uint32_t>(data[line_idx]),
+                          data[line_idx + kVarTableIdxAddr],
+                          reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(base_addr + addr_offset[idx])),
+                          data[line_idx + kVarTableIdxLen] };
+    }
+  } else {
+    auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData()));
+    auto device_len = tv->GetSize() / row_num;
+    if (device_len <= 0 || device_len > data[kVarTableIdxLen]) {
+      GELOGE(FAILED, "Local embedding length is out of range, expect %lld, but %lld exactly.",
+             data[kVarTableIdxLen], device_len);
+      return FAILED;
+    }
+    for (auto idx = 0; idx < row_num; ++idx) {
+      FMK_INT64_MULCHECK(idx, kVarTableRowCnt)
+      auto line_idx = idx * kVarTableRowCnt;
+      addr_infos[idx] = { static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr,
+                          device_len };
+      local_addr += device_len;
+    }
   }

   return SUCCESS;
| @@ -226,6 +277,10 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| } | } | ||||
| vector<HcomRemoteAccessAddrInfo> addr_infos; | vector<HcomRemoteAccessAddrInfo> addr_infos; | ||||
| GE_CHK_STATUS_RET(ExtractTensor(context, addr_infos)); | GE_CHK_STATUS_RET(ExtractTensor(context, addr_infos)); | ||||
| if (addr_infos.empty()) { | |||||
| done_callback(); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto callback = [this](HcclResult status) { | auto callback = [this](HcclResult status) { | ||||
| if (status != HCCL_SUCCESS) { | if (status != HCCL_SUCCESS) { | ||||
| @@ -235,6 +290,11 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| this->cond_.notify_all(); | this->cond_.notify_all(); | ||||
| GELOGI("rdma callback success."); | GELOGI("rdma callback success."); | ||||
| }; | }; | ||||
| std::string executor_type = context.GetNodeItem().NodeType(); | |||||
| if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) { | |||||
| executor_type = context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD ? HCOMREMOTEREAD : HCOMREMOTEWRITE; | |||||
| } | |||||
| HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | ||||
| if (hccl_ret != HCCL_SUCCESS) { | if (hccl_ret != HCCL_SUCCESS) { | ||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExecInitialize failed, ret: 0x%X", hccl_ret); | GELOGE(HCCL_E_INTERNAL, "Call HcomExecInitialize failed, ret: 0x%X", hccl_ret); | ||||
| @@ -262,7 +322,7 @@ Status HcclNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const | |||||
| GE_CHK_STATUS_RET(task.Init(context), "hccl node load hccl so failed."); | GE_CHK_STATUS_RET(task.Init(context), "hccl node load hccl so failed."); | ||||
| // Allocate output memory; output memory for remote read is calculated when the node executes. | // Allocate output memory; output memory for remote read is calculated when the node executes. | ||||
| if (context.GetNodeItem().NodeType() != HCOMREMOTEREAD) { | |||||
| if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) == 0) { | |||||
| GE_CHK_STATUS_RET(context.AllocateOutputs(), "hccl node task allocate output failed."); | GE_CHK_STATUS_RET(context.AllocateOutputs(), "hccl node task allocate output failed."); | ||||
| } | } | ||||
| @@ -274,7 +334,7 @@ Status HcclNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const | |||||
| Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | ||||
| GELOGI("[%s] HcclNodeExecutor::LoadTask in.", node->GetName().c_str()); | GELOGI("[%s] HcclNodeExecutor::LoadTask in.", node->GetName().c_str()); | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| if (node->GetType() == HCOMREMOTEREAD || node->GetType() == HCOMREMOTEWRITE) { | |||||
| if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) { | |||||
| task = MakeShared<RdmaNodeTask>(); | task = MakeShared<RdmaNodeTask>(); | ||||
| } else { | } else { | ||||
| task = MakeShared<HcclNodeTask>(); | task = MakeShared<HcclNodeTask>(); | ||||
| @@ -55,9 +55,11 @@ class RdmaNodeTask : public NodeTask { | |||||
| private: | private: | ||||
| Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos); | Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos); | ||||
| std::pair<int64_t, int64_t> remote_index_; | std::pair<int64_t, int64_t> remote_index_; | ||||
| std::pair<int64_t, int64_t> offset_index_; | |||||
| int32_t local_index_ = 0; | int32_t local_index_ = 0; | ||||
| std::mutex hccl_mutex_; | std::mutex hccl_mutex_; | ||||
| std::condition_variable cond_; | std::condition_variable cond_; | ||||
| bool skip_flag_ = false; | |||||
| }; | }; | ||||
| class HcclNodeExecutor : public NodeExecutor { | class HcclNodeExecutor : public NodeExecutor { | ||||
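The hccl_mutex_/cond_ members give RdmaNodeTask a block-until-callback rendezvous: the executing thread enqueues the remote access and waits, and the HCCL completion callback notifies it. A generic sketch of that pattern (not GE's exact implementation):

    #include <condition_variable>
    #include <mutex>

    class CallbackWaiter {
     public:
      void Done() {  // called from the HCCL callback thread
        std::lock_guard<std::mutex> lock(mu_);
        done_ = true;
        cv_.notify_all();
      }
      void Wait() {  // called from the executing thread after enqueuing the access
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return done_; });  // predicate guards against spurious wakeups
      }
     private:
      std::mutex mu_;
      std::condition_variable cv_;
      bool done_ = false;
    };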
| @@ -29,8 +29,6 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace host_cpu { | namespace host_cpu { | ||||
| Status AssignKernel::Compute(TaskContext& context) { | Status AssignKernel::Compute(TaskContext& context) { | ||||
| GELOGI("[%s] compute begin.", node_->GetName().c_str()); | |||||
| auto ref_tensor = context.MutableInput(kAssignRefInputIndex); | auto ref_tensor = context.MutableInput(kAssignRefInputIndex); | ||||
| GE_CHECK_NOTNULL(ref_tensor); | GE_CHECK_NOTNULL(ref_tensor); | ||||
| const auto value_tensor = context.GetInput(kAssignValueInputIndex); | const auto value_tensor = context.GetInput(kAssignValueInputIndex); | ||||
| @@ -50,7 +48,7 @@ Status AssignKernel::Compute(TaskContext& context) { | |||||
| GE_CHK_STATUS_RET(context.SetOutput(kAssignRefOutputIndex, *ref_tensor), | GE_CHK_STATUS_RET(context.SetOutput(kAssignRefOutputIndex, *ref_tensor), | ||||
| "[%s] Failed to set output.", context.GetNodeName()); | "[%s] Failed to set output.", context.GetNodeName()); | ||||
| GELOGI("[%s] compute success.", node_->GetName().c_str()); | |||||
| GELOGD("[%s] compute success.", node_->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "hybrid/node_executor/host_cpu/kernel/data_kernel.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "hybrid/node_executor/host_cpu/kernel_factory.h" | |||||
| namespace { | |||||
| constexpr size_t kDataInputIndex = 0; | |||||
| constexpr size_t kDataOutputIndex = 0; | |||||
| }  // namespace | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| namespace host_cpu { | |||||
| Status DataKernel::Compute(TaskContext& context) { | |||||
| auto input = context.MutableInput(kDataInputIndex); | |||||
| GE_CHECK_NOTNULL(input); | |||||
| GE_CHK_STATUS_RET(context.SetOutput(kDataOutputIndex, *input), "[%s] Failed to set output.", context.GetNodeName()) | |||||
| GELOGD("[%s] compute success.", node_->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_KERNEL_CREATOR(Data, DataKernel); | |||||
| } // namespace host_cpu | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
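DataKernel is a pure pass-through: it forwards input 0 to output 0, so a Data node costs next to nothing on the host-CPU path. REGISTER_KERNEL_CREATOR binds the "Data" op type to a creator in the host-CPU kernel factory; the sketch below shows one common way such a registration macro is built, with stand-in types, since GE's real factory lives in kernel_factory.h and may differ:

    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    struct Node {};                         // stand-in for ge::Node
    using NodePtr = std::shared_ptr<Node>;  // stand-in for ge::NodePtr
    struct Kernel {                         // stand-in for host_cpu::Kernel
      explicit Kernel(const NodePtr &node) : node_(node) {}
      virtual ~Kernel() = default;
      NodePtr node_;
    };

    class KernelFactory {
     public:
      using Creator = std::function<std::shared_ptr<Kernel>(const NodePtr &)>;
      static KernelFactory &Instance() {
        static KernelFactory inst;  // lazily constructed, shared across the process
        return inst;
      }
      void Register(const std::string &type, Creator creator) { creators_[type] = std::move(creator); }
      std::shared_ptr<Kernel> Create(const std::string &type, const NodePtr &node) {
        auto it = creators_.find(type);
        return it == creators_.end() ? nullptr : it->second(node);
      }
     private:
      std::map<std::string, Creator> creators_;
    };

    // Runs at static-initialization time, so the kernel is registered before any model executes.
    #define REGISTER_KERNEL_CREATOR(type, clazz)                                    \
      static const bool g_##type##_registered = [] {                                \
        KernelFactory::Instance().Register(#type, [](const NodePtr &node) {         \
          return std::static_pointer_cast<Kernel>(std::make_shared<clazz>(node));   \
        });                                                                         \
        return true;                                                                \
      }()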
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_HYBRID_HOST_CPU_KERNEL_DATA_KERNEL_H_ | |||||
| #define GE_HYBRID_HOST_CPU_KERNEL_DATA_KERNEL_H_ | |||||
| #include "hybrid/node_executor/host_cpu/kernel/kernel.h" | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| namespace host_cpu { | |||||
| class DataKernel : public Kernel { | |||||
| public: | |||||
| DataKernel(const NodePtr &node) : Kernel(node) {} | |||||
| ~DataKernel() override = default; | |||||
| DataKernel &operator=(const DataKernel &op) = delete; | |||||
| DataKernel(const DataKernel &op) = delete; | |||||
| /** | |||||
| * @brief Compute for the node task. | |||||
| * @return execution status | |||||
| */ | |||||
| Status Compute(TaskContext& context) override; | |||||
| }; | |||||
| } // namespace host_cpu | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
| #endif // GE_HYBRID_HOST_CPU_KERNEL_DATA_KERNEL_H_ | |||||
| @@ -23,7 +23,7 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace host_cpu { | namespace host_cpu { | ||||
| Status NoOpKernel::Compute(TaskContext& context) { | Status NoOpKernel::Compute(TaskContext& context) { | ||||
| GELOGI("[%s] no need to compute.", node_->GetName().c_str()); | |||||
| GELOGD("[%s] no need to compute.", node_->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -30,8 +30,6 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace host_cpu { | namespace host_cpu { | ||||
| Status RandomUniformKernel::Compute(TaskContext& context) { | Status RandomUniformKernel::Compute(TaskContext& context) { | ||||
| GELOGI("[%s] compute begin.", node_->GetName().c_str()); | |||||
| int64_t seed = 0; | int64_t seed = 0; | ||||
| int64_t seed2 = 0; | int64_t seed2 = 0; | ||||
| (void)AttrUtils::GetInt(node_->GetOpDesc(), "seed", seed); | (void)AttrUtils::GetInt(node_->GetOpDesc(), "seed", seed); | ||||
| @@ -66,7 +64,7 @@ Status RandomUniformKernel::Compute(TaskContext& context) { | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| GELOGI("[%s] compute success.", node_->GetName().c_str()); | |||||
| GELOGD("[%s] compute success.", node_->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -23,8 +23,6 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace host_cpu { | namespace host_cpu { | ||||
| Status VariableKernel::Compute(TaskContext& context) { | Status VariableKernel::Compute(TaskContext& context) { | ||||
| GELOGI("[%s] compute begin.", node_->GetName().c_str()); | |||||
| auto tensor = context.GetVariable(node_->GetName()); | auto tensor = context.GetVariable(node_->GetName()); | ||||
| if (tensor == nullptr) { | if (tensor == nullptr) { | ||||
| GELOGE(PARAM_INVALID, "tensor is NULL."); | GELOGE(PARAM_INVALID, "tensor is NULL."); | ||||
| @@ -32,7 +30,7 @@ Status VariableKernel::Compute(TaskContext& context) { | |||||
| } | } | ||||
| // Constant & Variable Op has and only has one output | // Constant & Variable Op has and only has one output | ||||
| GE_CHK_STATUS_RET(context.SetOutput(0, *tensor), "[%s] Failed to set output.", context.GetNodeName()); | GE_CHK_STATUS_RET(context.SetOutput(0, *tensor), "[%s] Failed to set output.", context.GetNodeName()); | ||||
| GELOGI("[%s] compute success.", node_->GetName().c_str()); | |||||
| GELOGD("[%s] compute success.", node_->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -437,6 +437,7 @@ REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | |||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | |||||
| REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | ||||
| REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | ||||
| @@ -370,7 +370,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREDUCESC | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMSEND; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMSEND; | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMRECEIVE; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMRECEIVE; | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEREAD; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEREAD; | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEREFREAD; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEWRITE; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEWRITE; | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTESCATTERWRITE; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARASSIGN; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARASSIGN; | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARISINITIALIZEDOP; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARISINITIALIZEDOP; | ||||
| @@ -589,6 +589,7 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES | |||||
| #"graph/graph_load_unittest.cc" | #"graph/graph_load_unittest.cc" | ||||
| "graph/ge_executor_unittest.cc" | "graph/ge_executor_unittest.cc" | ||||
| "graph/load/model_helper_unittest.cc" | "graph/load/model_helper_unittest.cc" | ||||
| "graph/load/model_utils_unittest.cc" | |||||
| ) | ) | ||||
| set(PASS_TEST_FILES | set(PASS_TEST_FILES | ||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/load/new_model_manager/model_utils.h" | |||||
| #include "graph/manager/graph_var_manager.h" | |||||
| using namespace std; | |||||
| namespace ge { | |||||
| class UtestModelUtils : public testing::Test { | |||||
| protected: | |||||
| void TearDown() {} | |||||
| }; | |||||
| // test ModelUtils::GetVarAddr | |||||
| TEST_F(UtestModelUtils, get_var_addr_hbm) { | |||||
| uint8_t test = 2; | |||||
| uint8_t *pf = &test; | |||||
| RuntimeParam runtime_param; | |||||
| runtime_param.session_id = 0; | |||||
| runtime_param.logic_var_base = 0; | |||||
| runtime_param.var_base = pf; | |||||
| runtime_param.var_size = 16; | |||||
| int64_t offset = 8; | |||||
| EXPECT_EQ(VarManager::Instance(runtime_param.session_id)->Init(0, 0, 0, 0), SUCCESS); | |||||
| EXPECT_NE(VarManager::Instance(runtime_param.session_id)->var_resource_, nullptr); | |||||
| VarManager::Instance(runtime_param.session_id)->var_resource_->var_offset_map_[offset] = RT_MEMORY_HBM; | |||||
| std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>("test", "test"); | |||||
| uint8_t *var_addr = nullptr; | |||||
| EXPECT_EQ(ModelUtils::GetVarAddr(runtime_param, op_desc, offset, var_addr), SUCCESS); | |||||
| EXPECT_EQ(runtime_param.var_base + offset - runtime_param.logic_var_base, var_addr); | |||||
| VarManager::Instance(runtime_param.session_id)->Destory(); | |||||
| } | |||||
| TEST_F(UtestModelUtils, get_var_addr_rdma_hbm) { | |||||
| uint8_t test = 2; | |||||
| uint8_t *pf = &test; | |||||
| RuntimeParam runtime_param; | |||||
| runtime_param.session_id = 0; | |||||
| runtime_param.logic_var_base = 0; | |||||
| runtime_param.var_base = pf; | |||||
| int64_t offset = 8; | |||||
| EXPECT_EQ(VarManager::Instance(runtime_param.session_id)->Init(0, 0, 0, 0), SUCCESS); | |||||
| EXPECT_NE(VarManager::Instance(runtime_param.session_id)->var_resource_, nullptr); | |||||
| VarManager::Instance(runtime_param.session_id)->var_resource_->var_offset_map_[offset] = RT_MEMORY_RDMA_HBM; | |||||
| std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>("test", "test"); | |||||
| uint8_t *var_addr = nullptr; | |||||
| EXPECT_EQ(ModelUtils::GetVarAddr(runtime_param, op_desc, offset, var_addr), SUCCESS); | |||||
| EXPECT_EQ(reinterpret_cast<uint8_t *>(offset), var_addr); | |||||
| VarManager::Instance(runtime_param.session_id)->Destory(); | |||||
| } | |||||
| } // namespace ge | |||||
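The two tests above pin down the contract that the new RT_MEMORY_RDMA_HBM handling adds to ModelUtils::GetVarAddr: an HBM variable resolves to var_base + (offset - logic_var_base), while an RDMA-HBM variable treats the logical offset itself as the device address. A sketch of that resolution; the real GetVarAddr also validates ranges and handles other memory types:

    #include <cstdint>

    enum class VarMemType : uint32_t { kHbm = 0x2, kRdmaHbm = 0x3 };  // mirrors the RT_MEMORY_* values

    uint8_t *ResolveVarAddr(VarMemType type, uint8_t *var_base, uint64_t logic_var_base, int64_t offset) {
      switch (type) {
        case VarMemType::kHbm:
          return var_base + offset - logic_var_base;   // offset into the session's variable pool
        case VarMemType::kRdmaHbm:
          return reinterpret_cast<uint8_t *>(offset);  // logical offset is already a device address
      }
      return nullptr;
    }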
| @@ -34,6 +34,7 @@ extern "C" { | |||||
| */ | */ | ||||
| #define RT_MEMORY_DEFAULT ((uint32_t)0x0) // default memory on device | #define RT_MEMORY_DEFAULT ((uint32_t)0x0) // default memory on device | ||||
| #define RT_MEMORY_HBM ((uint32_t)0x2) // HBM memory on device | #define RT_MEMORY_HBM ((uint32_t)0x2) // HBM memory on device | ||||
| #define RT_MEMORY_RDMA_HBM ((uint32_t)0x3) // RDMA-HBM memory on device | |||||
| #define RT_MEMORY_DDR ((uint32_t)0x4) // DDR memory on device | #define RT_MEMORY_DDR ((uint32_t)0x4) // DDR memory on device | ||||
| #define RT_MEMORY_SPM ((uint32_t)0x8) // shared physical memory on device | #define RT_MEMORY_SPM ((uint32_t)0x8) // shared physical memory on device | ||||
| #define RT_MEMORY_P2P_HBM ((uint32_t)0x10) // HBM memory on other 4P device | #define RT_MEMORY_P2P_HBM ((uint32_t)0x10) // HBM memory on other 4P device | ||||
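One caveat worth noting: 0x3 is 0x2 | 0x1, so the new RT_MEMORY_RDMA_HBM value cannot be tested with a bitwise AND against RT_MEMORY_HBM the way the higher power-of-two values (0x4, 0x8, 0x10) might suggest; the low values evidently form an enumeration and should be compared for equality, e.g.:

    #include <cstdint>
    #include <cstdio>

    #define RT_MEMORY_HBM      ((uint32_t)0x2)
    #define RT_MEMORY_RDMA_HBM ((uint32_t)0x3)
    #define RT_MEMORY_DDR      ((uint32_t)0x4)

    const char *MemTypeName(uint32_t type) {
      switch (type) {  // equality comparison; (type & RT_MEMORY_HBM) would also match RDMA-HBM
        case RT_MEMORY_HBM:      return "HBM";
        case RT_MEMORY_RDMA_HBM: return "RDMA-HBM";
        case RT_MEMORY_DDR:      return "DDR";
        default:                 return "other";
      }
    }

    int main() {
      printf("%s\n", MemTypeName(RT_MEMORY_RDMA_HBM));  // prints "RDMA-HBM"
      return 0;
    }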