Browse Source

fix repeat rtmalloc device mem

tags/v1.3.0
zhou_lili 3 years ago
parent
commit
e8fcd806f6
2 changed files with 10 additions and 9 deletions
  1. +4
    -8
      ge/single_op/single_op.cc
  2. +6
    -1
      ge/single_op/single_op_manager.cc

+ 4
- 8
ge/single_op/single_op.cc View File

@@ -84,7 +84,7 @@ Status CalInputsHostMemSize(const std::vector<DataBuffer> &inputs,
inputs_size.emplace_back(index, input_size);
GE_CHK_STATUS_RET(CheckInt64AddOverflow(total_size, input_size), "Total size is beyond the INT64_MAX.");
total_size += input_size;
GELOGD("The %zu input mem type is host, tensor size is %ld.", index, input_size);
GELOGD("The %zu input mem type is host, the tensor size is %ld.", index, input_size);
}
index++;
}
@@ -99,20 +99,16 @@ Status UpdateInputsBufferAddr(StreamResource *stream_resource, rtStream_t stream
const std::vector<std::pair<size_t, uint64_t>> &inputs_size,
std::vector<DataBuffer> &update_buffers) {
GE_CHECK_NOTNULL(stream_resource);
if (stream_resource->Init() != SUCCESS) {
GELOGE(FAILED, "[Malloc][Memory]Failed to malloc device buffer.");
return FAILED;
}
auto dst_addr = reinterpret_cast<uint8_t *>(stream_resource->GetDeviceBufferAddr());
// copy host mem from input_buffer to device mem of dst_addr
for (const auto &input_size : inputs_size) {
size_t index = input_size.first;
auto index = input_size.first;
auto size = input_size.second;
GELOGD("Do H2D for %zu input, dst size is %zu, src length is %lu.", index, size, update_buffers[index].length);
GELOGD("Do h2d for %zu input, dst size is %zu, src length is %lu.", index, size, update_buffers[index].length);
GE_CHK_RT_RET(rtMemcpyAsync(dst_addr, size, update_buffers[index].data, update_buffers[index].length,
RT_MEMCPY_HOST_TO_DEVICE_EX, stream));
update_buffers[index].data = dst_addr;
dst_addr = reinterpret_cast<uint8_t *>(dst_addr + size);
dst_addr = dst_addr + size;
}
return SUCCESS;
}


+ 6
- 1
ge/single_op/single_op_manager.cc View File

@@ -81,8 +81,13 @@ StreamResource *SingleOpManager::GetResource(uintptr_t resource_id, rtStream_t s
auto it = stream_resources_.find(resource_id);
StreamResource *res = nullptr;
if (it == stream_resources_.end()) {
res = new (std::nothrow) StreamResource(resource_id);
res = new(std::nothrow) StreamResource(resource_id);
if (res != nullptr) {
if (res->Init() != SUCCESS) {
GELOGE(FAILED, "[Malloc][Memory]Failed to malloc device buffer.");
delete res;
return nullptr;
}
res->SetStream(stream);
stream_resources_.emplace(resource_id, res);
}


Loading…
Cancel
Save