@@ -32,7 +32,6 @@ | |||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "external/graph/ge_error_codes.h" | #include "external/graph/ge_error_codes.h" | ||||
#include "graph/manager/graph_mem_allocator.h" | |||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "graph/optimize/common/params.h" | #include "graph/optimize/common/params.h" | ||||
#include "external/graph/types.h" | #include "external/graph/types.h" | ||||
@@ -194,35 +194,6 @@ ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_na | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | |||||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||||
GE_CHECK_NOTNULL(base_ptr); | |||||
GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str()); | |||||
VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; | |||||
uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset; | |||||
return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr, | |||||
var_broadcast_info.input_size, session_id_); | |||||
} | |||||
ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||||
GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str()); | |||||
VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; | |||||
// subgraph base_ptr could be nullptr, task it as base 0 | |||||
uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset; | |||||
return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name, | |||||
var_tensor_desc, session_id_); | |||||
} | |||||
ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name, | |||||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||||
return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr); | |||||
} | |||||
bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } | bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } | ||||
rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { | rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { | ||||
@@ -638,16 +609,6 @@ bool VarManager::IsVarExist(const std::string &var_name) { | |||||
return var_resource_->IsVarExist(var_name); | return var_resource_->IsVarExist(var_name); | ||||
} | } | ||||
ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||||
uint8_t *base_ptr) { | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
if (var_resource_ == nullptr) { | |||||
GELOGW("VarManager has not been init."); | |||||
return ge::INTERNAL_ERROR; | |||||
} | |||||
return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr); | |||||
} | |||||
ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { | ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); | GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); | ||||
@@ -701,16 +662,6 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt | |||||
return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); | return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); | ||||
} | } | ||||
ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
if (var_resource_ == nullptr) { | |||||
GELOGW("VarManager has not been init."); | |||||
return ge::INTERNAL_ERROR; | |||||
} | |||||
return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr); | |||||
} | |||||
bool VarManager::IsVarAddr(const int64_t &offset) { | bool VarManager::IsVarAddr(const int64_t &offset) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
if (var_resource_ == nullptr) { | if (var_resource_ == nullptr) { | ||||
@@ -118,15 +118,6 @@ class VarResource { | |||||
ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | ||||
ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | |||||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); | |||||
ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); | |||||
ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||||
uint8_t *base_ptr); | |||||
Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { | Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { | ||||
if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { | if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { | ||||
GELOGW("Var name: %s has already set.", var_name.c_str()); | GELOGW("Var name: %s has already set.", var_name.c_str()); | ||||
@@ -234,16 +225,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ||||
ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||||
uint8_t *base_ptr); | |||||
ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ||||
ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | ||||
ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||||
uint8_t *base_ptr); | |||||
ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); | ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); | ||||
ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | ||||
@@ -415,72 +415,6 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src, | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace | } // namespace | ||||
Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||||
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { | |||||
GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr."); | |||||
uint8_t *src_host_addr = nullptr; | |||||
int64_t src_addr_size = 0; | |||||
GE_MAKE_GUARD_RTMEM(src_host_addr); | |||||
GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id)); | |||||
GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size); | |||||
GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, | |||||
"[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld", | |||||
src_addr_size, dst_addr_size); | |||||
GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||||
return SUCCESS; | |||||
} | |||||
Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | |||||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | |||||
GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. "); | |||||
uint8_t *host_addr = nullptr; | |||||
GE_MAKE_GUARD_RTMEM(host_addr); | |||||
GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size)); | |||||
GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
GE_CHK_STATUS_RET( | |||||
SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id)); | |||||
return SUCCESS; | |||||
} | |||||
Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||||
uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) { | |||||
GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed"); | |||||
uint8_t *src_addr = nullptr; | |||||
GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); | |||||
uint8_t *mem_addr = | |||||
src_addr - | |||||
static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) + | |||||
static_cast<int64_t>( | |||||
reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); | |||||
GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size)); | |||||
GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size); | |||||
return SUCCESS; | |||||
} | |||||
Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, | |||||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | |||||
uint8_t *dst_addr = nullptr; | |||||
GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); | |||||
uint8_t *mem_addr = | |||||
dst_addr - | |||||
static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) + | |||||
static_cast<int64_t>( | |||||
reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); | |||||
GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||||
GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); | |||||
return SUCCESS; | |||||
} | |||||
Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | ||||
uint64_t session_id, | uint64_t session_id, | ||||
rtContext_t context, | rtContext_t context, | ||||
@@ -29,11 +29,6 @@ | |||||
namespace ge { | namespace ge { | ||||
class TransVarDataUtils { | class TransVarDataUtils { | ||||
public: | public: | ||||
static ge::Status SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||||
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id_); | |||||
static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | |||||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | |||||
static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes, | static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes, | ||||
uint64_t session_id, | uint64_t session_id, | ||||
rtContext_t context, | rtContext_t context, | ||||
@@ -41,12 +36,6 @@ class TransVarDataUtils { | |||||
uint32_t thread_num = 16); | uint32_t thread_num = 16); | ||||
static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); | static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); | ||||
private: | |||||
static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||||
uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); | |||||
static ge::Status SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, | |||||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||