diff --git a/ge/graph/manager/host_mem_manager.cc b/ge/graph/manager/host_mem_manager.cc index 60a7586d..d77bf3b2 100644 --- a/ge/graph/manager/host_mem_manager.cc +++ b/ge/graph/manager/host_mem_manager.cc @@ -104,15 +104,15 @@ Status HostMemManager::MallocSharedMemory(SharedMemInfo &mem_info) { return SUCCESS; } -Status HostMemManager::QueryVarMemInfo(const string &op_name, uint64_t &base_addr, uint64_t &data_size) { +bool HostMemManager::QueryVarMemInfo(const string &op_name, SharedMemInfo &mem_info) { std::lock_guard lock(mutex_); - if (var_memory_base_map_.find(op_name) == var_memory_base_map_.end()) { - GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str()); - return INTERNAL_ERROR; + auto it = var_memory_base_map_.find(op_name); + if (it == var_memory_base_map_.end()) { + GELOGW("Host memory for node [%s] not found.", op_name.c_str()); + return false; } - base_addr = static_cast(reinterpret_cast(var_memory_base_map_[op_name].device_address)); - data_size = var_memory_base_map_[op_name].mem_size; - return SUCCESS; + mem_info = it->second; + return true; } string HostMemManager::OpNameToShmName(const string &op_name) { diff --git a/ge/graph/manager/host_mem_manager.h b/ge/graph/manager/host_mem_manager.h index be3237c3..84d5aebe 100644 --- a/ge/graph/manager/host_mem_manager.h +++ b/ge/graph/manager/host_mem_manager.h @@ -66,7 +66,7 @@ class HostMemManager { Status Initialize(); void Finalize() noexcept; Status MallocSharedMemory(SharedMemInfo &mem_nfo); - Status QueryVarMemInfo(const string &op_name, uint64_t &base_addr, uint64_t &data_size); + bool QueryVarMemInfo(const string &op_name, SharedMemInfo &mem_info); private: static string OpNameToShmName(const string &op_name); diff --git a/ge/graph/manager/memory_api.cc b/ge/graph/manager/memory_api.cc index 0798eb51..2bd7b71f 100644 --- a/ge/graph/manager/memory_api.cc +++ b/ge/graph/manager/memory_api.cc @@ -106,7 +106,14 @@ Status MallocSharedMemory(const TensorInfo &tensor_info, uint64_t &dev_addr, uin } Status GetVarBaseAddrAndSize(const string &var_name, uint64_t &base_addr, uint64_t &var_size) { - GELOGD("GetVarBaseAddrAndSize in"); - return HostMemManager::Instance().QueryVarMemInfo(var_name, base_addr, var_size); + GELOGD("GetVarBaseAddrAndSize in, var name:[%s]", var_name.c_str()); + SharedMemInfo mem_info; + if (!HostMemManager::Instance().QueryVarMemInfo(var_name, mem_info)) { + GELOGE(FAILED, "Get addr and size failed, name:[%s]", var_name.c_str()); + return FAILED; + } + base_addr = static_cast(reinterpret_cast(mem_info.host_aligned_ptr->Get())); + var_size = mem_info.mem_size; + return SUCCESS; } } // namespace ge diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index b329aaa6..ef738b1c 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1012,20 +1012,25 @@ Status HybridModelBuilder::InitVariableTensors() { GELOGE(INTERNAL_ERROR, "Calculate variable size failed, node name:%s", it.first.c_str()); return INTERNAL_ERROR; } - SharedMemInfo mem_info(it.first, tensor_size); - if (HostMemManager::Instance().MallocSharedMemory(mem_info) != SUCCESS) { - GELOGE(GE_GRAPH_MALLOC_FAILED, "Host variable [%s] malloc failed.", it.first.c_str()); - return GE_GRAPH_MALLOC_FAILED; - } - if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(mem_info.host_aligned_ptr, - tensor_size) == nullptr) { - GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed."); + + // Host variable will be assigned to allocated shared memory first. + SharedMemInfo mem_info; + void *mem_addr = nullptr; + if (HostMemManager::Instance().QueryVarMemInfo(it.first, mem_info)) { + mem_addr = const_cast(MemManager::Instance().HostMemInstance(RT_MEMORY_HBM) + .Malloc(mem_info.host_aligned_ptr, tensor_size)); + } else { + mem_addr = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(tensor_size); + } + + if (mem_addr == nullptr) { + REPORT_INNER_ERROR("E19999", "[Malloc][HostMem] for variable [%s] failed.", it.first.c_str()); + GELOGE(MEMALLOC_FAILED, "[Malloc][HostMem] for variable [%s] failed.", it.first.c_str()); return MEMALLOC_FAILED; } GELOGD("Host variable [%s] malloc success, size=%ld.", it.first.c_str(), tensor_size); - std::unique_ptr tensor(new (std::nothrow) TensorValue(mem_info.host_aligned_ptr->MutableGet(), - tensor_size)); + std::unique_ptr tensor(new (std::nothrow) TensorValue(mem_addr, tensor_size)); GE_CHECK_NOTNULL(tensor); hybrid_model_.variable_tensors_.emplace(it.first, std::move(tensor)); }