From a892b2bf901e9939e49d8125014dbaa599519902 Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Tue, 19 Jan 2021 12:35:38 +0800 Subject: [PATCH] cache support --- .../load/new_model_manager/model_utils.cc | 25 ++++++++++++------- ge/graph/manager/graph_var_manager.cc | 4 +-- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/ge/graph/load/new_model_manager/model_utils.cc b/ge/graph/load/new_model_manager/model_utils.cc index efd8c619..d9a9f3ca 100755 --- a/ge/graph/load/new_model_manager/model_utils.cc +++ b/ge/graph/load/new_model_manager/model_utils.cc @@ -379,17 +379,24 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co /// Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset, uint8_t *&var_addr) { - if (ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset) == RT_MEMORY_RDMA_HBM) { - if (offset < 0) { - GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast(offset)); + rtMemType_t mem_type = ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset); + switch (mem_type) { + case RT_MEMORY_RDMA_HBM: + if (offset < 0) { + GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast(offset)); + return PARAM_INVALID; + } + var_addr = reinterpret_cast(offset); + break; + case RT_MEMORY_HBM: + VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base); + var_addr = model_param.var_base + offset - model_param.logic_var_base; + break; + default: + GELOGE(PARAM_INVALID, "unsupported memory type %u", mem_type); return PARAM_INVALID; - } - var_addr = reinterpret_cast(offset); - GE_CHECK_NOTNULL(var_addr); - } else { - VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base); - var_addr = model_param.var_base + offset - model_param.logic_var_base; } + GE_CHECK_NOTNULL(var_addr); return SUCCESS; } diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index 928c893f..8a829d47 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -212,7 +212,7 @@ rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { if (var_offset_map_.count(offset) > 0) { return var_offset_map_[offset]; } - return RT_MEMORY_HBM; + return RT_MEMORY_RESERVED; } VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) { @@ -660,7 +660,7 @@ rtMemType_t VarManager::GetVarMemType(const int64_t &offset) { std::lock_guard lock(mutex_); if (var_resource_ == nullptr) { GELOGW("VarManager has not been init."); - return RT_MEMORY_HBM; + return RT_MEMORY_RESERVED; } return var_resource_->GetVarMemType(offset); }