Browse Source

!15605 Fix async gpu memcpy issue

From: @TFbunny
Reviewed-by: @wilfchen,@liangchenghui
Signed-off-by: @liangchenghui
pull/15605/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
6f11902b18
4 changed files with 8 additions and 4 deletions
  1. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h
  2. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h
  3. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h
  4. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h

+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -72,6 +72,7 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
cudaMemcpyAsync(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to copy gpu memory.");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
CHECK_CUDA_RET_WITH_ERROR(
kernel_node_,
cudaMemcpyAsync(d_x, dout, x_size, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -62,6 +62,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy gpu memoy failed.");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
if (x == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null.";
return false;


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -60,6 +60,7 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy gpu memoy failed.");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
if (d_batch_mean == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null.";
return false;


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -52,6 +52,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
cudaMemcpyAsync(&array_input_[0], input, input_size_ * sizeof(T),
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync sampled_candidates failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
for (const auto item : array_input_) {
set_input_.insert(item);
}


Loading…
Cancel
Save