@@ -32,7 +32,6 @@ | |||
#include "graph/ge_attr_value.h" | |||
#include "graph/ge_context.h" | |||
#include "external/graph/ge_error_codes.h" | |||
#include "graph/manager/graph_mem_allocator.h" | |||
#include "graph/manager/graph_var_manager.h" | |||
#include "graph/optimize/common/params.h" | |||
#include "external/graph/types.h" | |||
@@ -194,35 +194,6 @@ ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_na | |||
return SUCCESS; | |||
} | |||
ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | |||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||
GE_CHECK_NOTNULL(base_ptr); | |||
GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str()); | |||
VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; | |||
uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset; | |||
return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr, | |||
var_broadcast_info.input_size, session_id_); | |||
} | |||
ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||
GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str()); | |||
VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; | |||
// subgraph base_ptr could be nullptr, task it as base 0 | |||
uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset; | |||
return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name, | |||
var_tensor_desc, session_id_); | |||
} | |||
ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name, | |||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||
return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr); | |||
} | |||
bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } | |||
rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { | |||
@@ -638,16 +609,6 @@ bool VarManager::IsVarExist(const std::string &var_name) { | |||
return var_resource_->IsVarExist(var_name); | |||
} | |||
ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||
uint8_t *base_ptr) { | |||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||
if (var_resource_ == nullptr) { | |||
GELOGW("VarManager has not been init."); | |||
return ge::INTERNAL_ERROR; | |||
} | |||
return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr); | |||
} | |||
ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { | |||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||
GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); | |||
@@ -701,16 +662,6 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt | |||
return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); | |||
} | |||
ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||
if (var_resource_ == nullptr) { | |||
GELOGW("VarManager has not been init."); | |||
return ge::INTERNAL_ERROR; | |||
} | |||
return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr); | |||
} | |||
bool VarManager::IsVarAddr(const int64_t &offset) { | |||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||
if (var_resource_ == nullptr) { | |||
@@ -118,15 +118,6 @@ class VarResource { | |||
ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||
ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | |||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); | |||
ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); | |||
ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||
uint8_t *base_ptr); | |||
Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { | |||
if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { | |||
GELOGW("Var name: %s has already set.", var_name.c_str()); | |||
@@ -234,16 +225,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||
ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | |||
ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||
uint8_t *base_ptr); | |||
ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | |||
ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||
ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||
uint8_t *base_ptr); | |||
ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); | |||
ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | |||
@@ -415,72 +415,6 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src, | |||
return SUCCESS; | |||
} | |||
} // namespace | |||
Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { | |||
GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr."); | |||
uint8_t *src_host_addr = nullptr; | |||
int64_t src_addr_size = 0; | |||
GE_MAKE_GUARD_RTMEM(src_host_addr); | |||
GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id)); | |||
GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size); | |||
GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, | |||
"[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld", | |||
src_addr_size, dst_addr_size); | |||
GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||
return SUCCESS; | |||
} | |||
Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | |||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | |||
GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. "); | |||
uint8_t *host_addr = nullptr; | |||
GE_MAKE_GUARD_RTMEM(host_addr); | |||
GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size)); | |||
GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||
GE_CHK_STATUS_RET( | |||
SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id)); | |||
return SUCCESS; | |||
} | |||
Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||
uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) { | |||
GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed"); | |||
uint8_t *src_addr = nullptr; | |||
GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); | |||
uint8_t *mem_addr = | |||
src_addr - | |||
static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) + | |||
static_cast<int64_t>( | |||
reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); | |||
GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size)); | |||
GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||
GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size); | |||
return SUCCESS; | |||
} | |||
Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, | |||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | |||
uint8_t *dst_addr = nullptr; | |||
GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); | |||
uint8_t *mem_addr = | |||
dst_addr - | |||
static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) + | |||
static_cast<int64_t>( | |||
reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); | |||
GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||
GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); | |||
return SUCCESS; | |||
} | |||
Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | |||
uint64_t session_id, | |||
rtContext_t context, | |||
@@ -29,11 +29,6 @@ | |||
namespace ge { | |||
class TransVarDataUtils { | |||
public: | |||
static ge::Status SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id_); | |||
static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | |||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | |||
static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes, | |||
uint64_t session_id, | |||
rtContext_t context, | |||
@@ -41,12 +36,6 @@ class TransVarDataUtils { | |||
uint32_t thread_num = 16); | |||
static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); | |||
private: | |||
static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||
uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); | |||
static ge::Status SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, | |||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | |||
}; | |||
} // namespace ge | |||