
example_cudacode.py 1.1 kB

import torch
from torch.utils.cpp_extension import load_inline

# CUDA source: an element-wise ReLU kernel plus a C++ wrapper that launches it
relu_source = """
#include <torch/extension.h>
#include <cuda_runtime.h>

__global__ void relu_kernel(const float* x, float* y, int size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        y[idx] = fmaxf(x[idx], 0.f);
    }
}

torch::Tensor relu_cuda(torch::Tensor x) {
    auto size = x.numel();
    auto y = torch::empty_like(x);
    const int block_size = 256;
    int num_blocks = (size + block_size - 1) / block_size;
    relu_kernel<<<num_blocks, block_size>>>(x.data_ptr<float>(), y.data_ptr<float>(), size);
    return y;
}
"""

# C++ declaration of the function exposed to Python
relu_cpp_source = """
torch::Tensor relu_cuda(torch::Tensor x);
"""

# Compile the inline CUDA code
relu = load_inline(
    name="relu",
    cpp_sources=relu_cpp_source,
    cuda_sources=relu_source,
    functions=["relu_cuda"],
    verbose=True,
)


class ModelNew(torch.nn.Module):
    def __init__(self):
        super(ModelNew, self).__init__()
        self.relu = relu  # The module containing the compiled kernel

    def forward(self, x):
        return self.relu.relu_cuda(x)
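
A minimal usage sketch, not part of the original file: it assumes a CUDA-capable device is available and simply checks the custom kernel against torch.relu on a random float tensor.

if __name__ == "__main__":
    model = ModelNew().cuda()
    x = torch.randn(1024, device="cuda")
    y = model(x)
    # Sanity check against PyTorch's built-in ReLU
    assert torch.allclose(y, torch.relu(x))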