From e8fcd806f673ab7a3171d177fdae009df43f57da Mon Sep 17 00:00:00 2001 From: zhou_lili Date: Wed, 7 Apr 2021 15:54:03 +0800 Subject: [PATCH] fix repeat rtmalloc device mem --- ge/single_op/single_op.cc | 12 ++++-------- ge/single_op/single_op_manager.cc | 7 ++++++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ge/single_op/single_op.cc b/ge/single_op/single_op.cc index c305eea9..4b3f17cf 100755 --- a/ge/single_op/single_op.cc +++ b/ge/single_op/single_op.cc @@ -84,7 +84,7 @@ Status CalInputsHostMemSize(const std::vector &inputs, inputs_size.emplace_back(index, input_size); GE_CHK_STATUS_RET(CheckInt64AddOverflow(total_size, input_size), "Total size is beyond the INT64_MAX."); total_size += input_size; - GELOGD("The %zu input mem type is host, tensor size is %ld.", index, input_size); + GELOGD("The %zu input mem type is host, the tensor size is %ld.", index, input_size); } index++; } @@ -99,20 +99,16 @@ Status UpdateInputsBufferAddr(StreamResource *stream_resource, rtStream_t stream const std::vector> &inputs_size, std::vector &update_buffers) { GE_CHECK_NOTNULL(stream_resource); - if (stream_resource->Init() != SUCCESS) { - GELOGE(FAILED, "[Malloc][Memory]Failed to malloc device buffer."); - return FAILED; - } auto dst_addr = reinterpret_cast(stream_resource->GetDeviceBufferAddr()); // copy host mem from input_buffer to device mem of dst_addr for (const auto &input_size : inputs_size) { - size_t index = input_size.first; + auto index = input_size.first; auto size = input_size.second; - GELOGD("Do H2D for %zu input, dst size is %zu, src length is %lu.", index, size, update_buffers[index].length); + GELOGD("Do h2d for %zu input, dst size is %zu, src length is %lu.", index, size, update_buffers[index].length); GE_CHK_RT_RET(rtMemcpyAsync(dst_addr, size, update_buffers[index].data, update_buffers[index].length, RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); update_buffers[index].data = dst_addr; - dst_addr = reinterpret_cast(dst_addr + size); + dst_addr = dst_addr + size; } return SUCCESS; } diff --git a/ge/single_op/single_op_manager.cc b/ge/single_op/single_op_manager.cc index 6246d6a1..667e987b 100644 --- a/ge/single_op/single_op_manager.cc +++ b/ge/single_op/single_op_manager.cc @@ -81,8 +81,13 @@ StreamResource *SingleOpManager::GetResource(uintptr_t resource_id, rtStream_t s auto it = stream_resources_.find(resource_id); StreamResource *res = nullptr; if (it == stream_resources_.end()) { - res = new (std::nothrow) StreamResource(resource_id); + res = new(std::nothrow) StreamResource(resource_id); if (res != nullptr) { + if (res->Init() != SUCCESS) { + GELOGE(FAILED, "[Malloc][Memory]Failed to malloc device buffer."); + delete res; + return nullptr; + } res->SetStream(stream); stream_resources_.emplace(resource_id, res); }