You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example_cudacode.py 1.4 kB

2 months ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import torch
  2. import torch.nn as nn
  3. from torch.utils.cpp_extension import load_inline
  4. # 更简单的实现:只优化ReLU部分,矩阵乘法使用PyTorch
  5. relu_source = """
  6. #include <torch/extension.h>
  7. #include <cuda_runtime.h>
  8. __global__ void relu_kernel(const float* x, float* y, int size) {
  9. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  10. if (idx < size) {
  11. y[idx] = fmaxf(x[idx], 0.f);
  12. }
  13. }
  14. torch::Tensor relu_cuda(torch::Tensor x) {
  15. auto size = x.numel();
  16. auto y = torch::empty_like(x);
  17. const int block_size = 256;
  18. int num_blocks = (size + block_size - 1) / block_size;
  19. relu_kernel<<<num_blocks, block_size>>>(x.data_ptr<float>(), y.data_ptr<float>(), size);
  20. return y;
  21. }
  22. """
  23. relu_cpp_source = """
  24. torch::Tensor relu_cuda(torch::Tensor x);
  25. """
  26. # Compile the inline CUDA code
  27. relu = load_inline(
  28. name="relu",
  29. cpp_sources=relu_cpp_source,
  30. cuda_sources=relu_source,
  31. functions=["relu_cuda"],
  32. verbose=True
  33. )
  34. class ModelNew(torch.nn.Module):
  35. def __init__(self, weight):
  36. super(ModelNew, self).__init__()
  37. self.weight = nn.Parameter(weight)
  38. self.relu = relu # The module containing the kernel
  39. def forward(self, x):
  40. # 使用PyTorch的矩阵乘法,只优化ReLU部分
  41. x = torch.matmul(x, self.weight)
  42. return self.relu.relu_cuda(x)