Browse Source

cache support

tags/v1.2.0
chenyemeng 3 years ago
parent
commit
a892b2bf90
2 changed files with 18 additions and 11 deletions
  1. +16
    -9
      ge/graph/load/new_model_manager/model_utils.cc
  2. +2
    -2
      ge/graph/manager/graph_var_manager.cc

+ 16
- 9
ge/graph/load/new_model_manager/model_utils.cc View File

@@ -379,17 +379,24 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co
///
Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset,
uint8_t *&var_addr) {
if (ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset) == RT_MEMORY_RDMA_HBM) {
if (offset < 0) {
GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset));
rtMemType_t mem_type = ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset);
switch (mem_type) {
case RT_MEMORY_RDMA_HBM:
if (offset < 0) {
GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset));
return PARAM_INVALID;
}
var_addr = reinterpret_cast<uint8_t *>(offset);
break;
case RT_MEMORY_HBM:
VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base);
var_addr = model_param.var_base + offset - model_param.logic_var_base;
break;
default:
GELOGE(PARAM_INVALID, "unsupported memory type %u", mem_type);
return PARAM_INVALID;
}
var_addr = reinterpret_cast<uint8_t *>(offset);
GE_CHECK_NOTNULL(var_addr);
} else {
VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base);
var_addr = model_param.var_base + offset - model_param.logic_var_base;
}
GE_CHECK_NOTNULL(var_addr);
return SUCCESS;
}



+ 2
- 2
ge/graph/manager/graph_var_manager.cc View File

@@ -212,7 +212,7 @@ rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
if (var_offset_map_.count(offset) > 0) {
return var_offset_map_[offset];
}
return RT_MEMORY_HBM;
return RT_MEMORY_RESERVED;
}

VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
@@ -660,7 +660,7 @@ rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_resource_ == nullptr) {
GELOGW("VarManager has not been init.");
return RT_MEMORY_HBM;
return RT_MEMORY_RESERVED;
}
return var_resource_->GetVarMemType(offset);
}


Loading…
Cancel
Save