|
|
@@ -379,17 +379,24 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co |
|
|
|
/// |
|
|
|
Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset, |
|
|
|
uint8_t *&var_addr) { |
|
|
|
if (ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset) == RT_MEMORY_RDMA_HBM) { |
|
|
|
if (offset < 0) { |
|
|
|
GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset)); |
|
|
|
rtMemType_t mem_type = ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset); |
|
|
|
switch (mem_type) { |
|
|
|
case RT_MEMORY_RDMA_HBM: |
|
|
|
if (offset < 0) { |
|
|
|
GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset)); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
var_addr = reinterpret_cast<uint8_t *>(offset); |
|
|
|
break; |
|
|
|
case RT_MEMORY_HBM: |
|
|
|
VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base); |
|
|
|
var_addr = model_param.var_base + offset - model_param.logic_var_base; |
|
|
|
break; |
|
|
|
default: |
|
|
|
GELOGE(PARAM_INVALID, "unsupported memory type %u", mem_type); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
var_addr = reinterpret_cast<uint8_t *>(offset); |
|
|
|
GE_CHECK_NOTNULL(var_addr); |
|
|
|
} else { |
|
|
|
VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base); |
|
|
|
var_addr = model_param.var_base + offset - model_param.logic_var_base; |
|
|
|
} |
|
|
|
GE_CHECK_NOTNULL(var_addr); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|