| @@ -32,7 +32,6 @@ | |||
| #include "graph/ge_attr_value.h" | |||
| #include "graph/ge_context.h" | |||
| #include "external/graph/ge_error_codes.h" | |||
| #include "graph/manager/graph_mem_allocator.h" | |||
| #include "graph/manager/graph_var_manager.h" | |||
| #include "graph/optimize/common/params.h" | |||
| #include "external/graph/types.h" | |||
| @@ -194,35 +194,6 @@ ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_na | |||
| return SUCCESS; | |||
| } | |||
| ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | |||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||
| GE_CHECK_NOTNULL(base_ptr); | |||
| GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str()); | |||
| VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; | |||
| uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset; | |||
| return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr, | |||
| var_broadcast_info.input_size, session_id_); | |||
| } | |||
| ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||
| GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str()); | |||
| VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; | |||
| // subgraph base_ptr could be nullptr, task it as base 0 | |||
| uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset; | |||
| return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name, | |||
| var_tensor_desc, session_id_); | |||
| } | |||
| ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name, | |||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||
| return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr); | |||
| } | |||
| bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } | |||
| rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { | |||
| @@ -638,16 +609,6 @@ bool VarManager::IsVarExist(const std::string &var_name) { | |||
| return var_resource_->IsVarExist(var_name); | |||
| } | |||
| ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||
| uint8_t *base_ptr) { | |||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||
| if (var_resource_ == nullptr) { | |||
| GELOGW("VarManager has not been init."); | |||
| return ge::INTERNAL_ERROR; | |||
| } | |||
| return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr); | |||
| } | |||
| ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { | |||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||
| GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); | |||
| @@ -701,16 +662,6 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt | |||
| return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); | |||
| } | |||
| ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||
| if (var_resource_ == nullptr) { | |||
| GELOGW("VarManager has not been init."); | |||
| return ge::INTERNAL_ERROR; | |||
| } | |||
| return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr); | |||
| } | |||
| bool VarManager::IsVarAddr(const int64_t &offset) { | |||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||
| if (var_resource_ == nullptr) { | |||
| @@ -118,15 +118,6 @@ class VarResource { | |||
| ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||
/// @brief Copy a variable's data into the broadcast input slot recorded for (graph_id, var_name).
///        The definition rejects a null base_ptr, adds the record's input_offset to base_ptr and
///        delegates to TransVarDataUtils::SyncVarData2BroadCast with the record's input_size.
/// @return status of the TransVarDataUtils call
| ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | |||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); | |||
/// @brief Copy data from the broadcast output slot recorded for (graph_id, var_name) back into the variable.
///        The definition adds the record's output_offset to base_ptr (which, per its own comment, may be
///        nullptr for a subgraph) and delegates to TransVarDataUtils::SyncBroadCastData2Var.
/// @return status of the TransVarDataUtils call
| ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); | |||
/// @brief Synchronise a variable's data; the definition forwards all arguments to SyncVarData2BroadCast.
/// @return whatever SyncVarData2BroadCast returns
| ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||
| uint8_t *base_ptr); | |||
| Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { | |||
| if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { | |||
| GELOGW("Var name: %s has already set.", var_name.c_str()); | |||
| @@ -234,16 +225,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||
| ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | |||
/// @brief Locked entry point for variable synchronisation: the definition takes mutex_ and forwards to
///        VarResource::SyncVarData.
/// @return ge::INTERNAL_ERROR when the manager has no VarResource, otherwise the forwarded status
| ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||
| uint8_t *base_ptr); | |||
| ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | |||
| ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||
/// @brief Locked entry point for copying broadcast output back into a variable: the definition takes
///        mutex_ and forwards to VarResource::SyncBroadCastData2Var.
/// @return ge::INTERNAL_ERROR when the manager has no VarResource, otherwise the forwarded status
| ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||
| uint8_t *base_ptr); | |||
| ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); | |||
| ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | |||
| @@ -415,72 +415,6 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src, | |||
| return SUCCESS; | |||
| } | |||
| } // namespace | |||
| Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||
| uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { | |||
| GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr."); | |||
| uint8_t *src_host_addr = nullptr; | |||
| int64_t src_addr_size = 0; | |||
| GE_MAKE_GUARD_RTMEM(src_host_addr); | |||
| GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id)); | |||
| GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size); | |||
| GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, | |||
| "[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld", | |||
| src_addr_size, dst_addr_size); | |||
| GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||
| return SUCCESS; | |||
| } | |||
| Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | |||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | |||
| GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. "); | |||
| uint8_t *host_addr = nullptr; | |||
| GE_MAKE_GUARD_RTMEM(host_addr); | |||
| GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size)); | |||
| GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||
| GE_CHK_STATUS_RET( | |||
| SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id)); | |||
| return SUCCESS; | |||
| } | |||
| Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||
| uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) { | |||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed"); | |||
| uint8_t *src_addr = nullptr; | |||
| GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); | |||
| uint8_t *mem_addr = | |||
| src_addr - | |||
| static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) + | |||
| static_cast<int64_t>( | |||
| reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); | |||
| GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size)); | |||
| GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||
| GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size); | |||
| return SUCCESS; | |||
| } | |||
| Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, | |||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | |||
| uint8_t *dst_addr = nullptr; | |||
| GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); | |||
| uint8_t *mem_addr = | |||
| dst_addr - | |||
| static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) + | |||
| static_cast<int64_t>( | |||
| reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); | |||
| GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||
| GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); | |||
| return SUCCESS; | |||
| } | |||
| Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | |||
| uint64_t session_id, | |||
| rtContext_t context, | |||
| @@ -29,11 +29,6 @@ | |||
| namespace ge { | |||
| class TransVarDataUtils { | |||
| public: | |||
/// @brief Stage a variable's data on the host (via SyncTensorToHost) and copy it to dst_addr with an
///        RT_MEMCPY_HOST_TO_DEVICE rtMemcpy.
/// @return FAILED when dst_addr is null or the variable's size differs from dst_addr_size; otherwise
///         SUCCESS or the status of a failing helper/runtime call
| static ge::Status SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||
| uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id_); | |||
/// @brief Copy src_addr_size bytes from src_addr into a buffer obtained from rtMallocHost, then write
///        that buffer into the variable through SyncTensorToDevice.
/// @return FAILED when src_addr is null; otherwise SUCCESS or the status of a failing helper/runtime call
| static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | |||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | |||
| static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes, | |||
| uint64_t session_id, | |||
| rtContext_t context, | |||
| @@ -41,12 +36,6 @@ class TransVarDataUtils { | |||
| uint32_t thread_num = 16); | |||
| static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); | |||
| private: | |||
/// @brief Read a variable's data into a buffer allocated with rtMallocHost.
///        On success *host_addr points at the buffer and addr_size holds the size reported by
///        TensorUtils::GetSize for src_tensor_desc.
///        NOTE(review): the definition never frees the buffer, so ownership appears to pass to the caller — confirm.
| static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||
| uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); | |||
/// @brief Copy addr_size bytes from host_addr into the variable's memory with an RT_MEMCPY_HOST_TO_DEVICE rtMemcpy.
///        The byte count is a uint32_t, so callers holding a 64-bit size must make sure it fits.
| static ge::Status SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, | |||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | |||
| }; | |||
| } // namespace ge | |||