diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index 5d440f00..5ec70691 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -542,11 +542,7 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid, memory_type = %u.", memory_type); return ge::INTERNAL_ERROR; } - result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset); - if (result != SUCCESS) { - GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed."); - return ge::INTERNAL_ERROR; - } + if (var_resource_ == nullptr) { REPORT_INNER_ERROR("E19999", "VarManager has not been init, memory_type:%d, session_id:%lu, " "check invalid", memory_type, session_id_); @@ -554,31 +550,46 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen return ge::INTERNAL_ERROR; } - result = var_resource_->SaveVarAddr( - var_name, tensor_desc, reinterpret_cast(static_cast(mem_offset)), memory_type); - if (result != SUCCESS) { - GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed."); - return ge::INTERNAL_ERROR; + ge::GeTensorDesc cur_tensor_desc; + int64_t cur_tensor_desc_size = 0; + result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc); + // reuse old format variable memory + if (result == SUCCESS) { + result = var_resource_->GetVarAddr( + var_name, tensor_desc, reinterpret_cast(reinterpret_cast(&mem_offset)), memory_type); + if (result == SUCCESS) { + result = TensorUtils::GetSize(cur_tensor_desc, cur_tensor_desc_size); + GELOGD("tensor_desc_size is %ld, cur_tensor_desc_size is %ld, memoffset is %zu", tensor_desc_size, + cur_tensor_desc_size, mem_offset); + } } - result = var_resource_->GetVarAddr( - var_name, tensor_desc, reinterpret_cast(reinterpret_cast(&mem_offset)), memory_type); - if (result != SUCCESS) { - GELOGE(ge::INTERNAL_ERROR, "GetVarAddr by offset failed."); - return ge::INTERNAL_ERROR; - } + bool can_not_reuse_old_memory = (result != SUCCESS) || (tensor_desc_size > cur_tensor_desc_size); + if (can_not_reuse_old_memory) { + result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset); + if (result != SUCCESS) { + GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed."); + return ge::INTERNAL_ERROR; + } - ge::GeTensorDesc cur_tensor_desc; + result = var_resource_->SaveVarAddr( + var_name, tensor_desc, reinterpret_cast(static_cast(mem_offset)), memory_type); + if (result != SUCCESS) { + GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed."); + return ge::INTERNAL_ERROR; + } + } + // old not exist only save new tensor result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc); if (result != SUCCESS) { var_resource_->SetVarAddr(var_name, tensor_desc, reinterpret_cast(static_cast(mem_offset)), memory_type); return SUCCESS; } - - if (cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() || - cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() || - cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims()) { + bool format_changed = cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() || + cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() || + cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims(); + if (format_changed) { GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(), ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),