Browse Source

Optimize shared memory for widedeep

pull/1762/head
isaacxr 4 years ago
parent
commit
239b51b5f1
4 changed files with 32 additions and 20 deletions
  1. +7
    -7
      ge/graph/manager/host_mem_manager.cc
  2. +1
    -1
      ge/graph/manager/host_mem_manager.h
  3. +9
    -2
      ge/graph/manager/memory_api.cc
  4. +15
    -10
      ge/hybrid/model/hybrid_model_builder.cc

+ 7
- 7
ge/graph/manager/host_mem_manager.cc View File

@@ -104,15 +104,15 @@ Status HostMemManager::MallocSharedMemory(SharedMemInfo &mem_info) {
return SUCCESS; return SUCCESS;
} }


Status HostMemManager::QueryVarMemInfo(const string &op_name, uint64_t &base_addr, uint64_t &data_size) {
bool HostMemManager::QueryVarMemInfo(const string &op_name, SharedMemInfo &mem_info) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_memory_base_map_.find(op_name) == var_memory_base_map_.end()) {
GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str());
return INTERNAL_ERROR;
auto it = var_memory_base_map_.find(op_name);
if (it == var_memory_base_map_.end()) {
GELOGW("Host memory for node [%s] not found.", op_name.c_str());
return false;
} }
base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(var_memory_base_map_[op_name].device_address));
data_size = var_memory_base_map_[op_name].mem_size;
return SUCCESS;
mem_info = it->second;
return true;
} }


string HostMemManager::OpNameToShmName(const string &op_name) { string HostMemManager::OpNameToShmName(const string &op_name) {


+ 1
- 1
ge/graph/manager/host_mem_manager.h View File

@@ -66,7 +66,7 @@ class HostMemManager {
Status Initialize(); Status Initialize();
void Finalize() noexcept; void Finalize() noexcept;
Status MallocSharedMemory(SharedMemInfo &mem_nfo); Status MallocSharedMemory(SharedMemInfo &mem_nfo);
Status QueryVarMemInfo(const string &op_name, uint64_t &base_addr, uint64_t &data_size);
bool QueryVarMemInfo(const string &op_name, SharedMemInfo &mem_info);


private: private:
static string OpNameToShmName(const string &op_name); static string OpNameToShmName(const string &op_name);


+ 9
- 2
ge/graph/manager/memory_api.cc View File

@@ -106,7 +106,14 @@ Status MallocSharedMemory(const TensorInfo &tensor_info, uint64_t &dev_addr, uin
} }


Status GetVarBaseAddrAndSize(const string &var_name, uint64_t &base_addr, uint64_t &var_size) { Status GetVarBaseAddrAndSize(const string &var_name, uint64_t &base_addr, uint64_t &var_size) {
GELOGD("GetVarBaseAddrAndSize in");
return HostMemManager::Instance().QueryVarMemInfo(var_name, base_addr, var_size);
GELOGD("GetVarBaseAddrAndSize in, var name:[%s]", var_name.c_str());
SharedMemInfo mem_info;
if (!HostMemManager::Instance().QueryVarMemInfo(var_name, mem_info)) {
GELOGE(FAILED, "Get addr and size failed, name:[%s]", var_name.c_str());
return FAILED;
}
base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(mem_info.host_aligned_ptr->Get()));
var_size = mem_info.mem_size;
return SUCCESS;
} }
} // namespace ge } // namespace ge

+ 15
- 10
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -1012,20 +1012,25 @@ Status HybridModelBuilder::InitVariableTensors() {
GELOGE(INTERNAL_ERROR, "Calculate variable size failed, node name:%s", it.first.c_str()); GELOGE(INTERNAL_ERROR, "Calculate variable size failed, node name:%s", it.first.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
SharedMemInfo mem_info(it.first, tensor_size);
if (HostMemManager::Instance().MallocSharedMemory(mem_info) != SUCCESS) {
GELOGE(GE_GRAPH_MALLOC_FAILED, "Host variable [%s] malloc failed.", it.first.c_str());
return GE_GRAPH_MALLOC_FAILED;
}
if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(mem_info.host_aligned_ptr,
tensor_size) == nullptr) {
GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed.");

// Host variable will be assigned to allocated shared memory first.
SharedMemInfo mem_info;
void *mem_addr = nullptr;
if (HostMemManager::Instance().QueryVarMemInfo(it.first, mem_info)) {
mem_addr = const_cast<void *>(MemManager::Instance().HostMemInstance(RT_MEMORY_HBM)
.Malloc(mem_info.host_aligned_ptr, tensor_size));
} else {
mem_addr = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(tensor_size);
}

if (mem_addr == nullptr) {
REPORT_INNER_ERROR("E19999", "[Malloc][HostMem] for variable [%s] failed.", it.first.c_str());
GELOGE(MEMALLOC_FAILED, "[Malloc][HostMem] for variable [%s] failed.", it.first.c_str());
return MEMALLOC_FAILED; return MEMALLOC_FAILED;
} }
GELOGD("Host variable [%s] malloc success, size=%ld.", it.first.c_str(), tensor_size); GELOGD("Host variable [%s] malloc success, size=%ld.", it.first.c_str(), tensor_size);


std::unique_ptr<TensorValue> tensor(new (std::nothrow) TensorValue(mem_info.host_aligned_ptr->MutableGet(),
tensor_size));
std::unique_ptr<TensorValue> tensor(new (std::nothrow) TensorValue(mem_addr, tensor_size));
GE_CHECK_NOTNULL(tensor); GE_CHECK_NOTNULL(tensor);
hybrid_model_.variable_tensors_.emplace(it.first, std::move(tensor)); hybrid_model_.variable_tensors_.emplace(it.first, std::move(tensor));
} }


Loading…
Cancel
Save