// custom_cuda_kernel.cu

  #include "custom_cuda_layers.h"

  #include <cuda_fp16.h>
  // Element-wise float -> __half conversion kernel.
  //
  // Each thread handles at most one element: thread `id` reads input[id],
  // converts it to half precision and stores the result in output[id].
  // Threads whose index is >= size do nothing, so the grid may over-cover
  // the data. Only the .x components of blockIdx / blockDim / threadIdx are
  // used, i.e. a 1-D launch is expected.
  //
  // input  : source array, read at indices [0, size)
  // output : destination array, written at indices [0, size)
  // size   : number of elements to convert
  __global__ void param_update_kernel(const float* input, __half* output, int size)
  {
      // Build the global index in 64 bits. In 32-bit unsigned arithmetic the
      // product blockIdx.x * blockDim.x can reach 2^31 or more for large
      // launches; narrowed to int that becomes negative and would slip past
      // the bounds check below.
      const long long id = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

      // Explicit intrinsic instead of a C-style cast: it does not depend on
      // __half's implicit conversion constructors, which are compiled out when
      // __CUDA_NO_HALF_CONVERSIONS__ is defined.
      if (id < size) { output[id] = __float2half(input[id]); }
  }
  // Enqueues param_update_kernel on `stream`, one thread per element.
  //
  // The launch uses 1024 threads per block and enough blocks to cover `size`
  // elements (size divided by 1024, rounded up). This function only issues
  // the launch; it performs no stream synchronization and no error query.
  //
  // input  : float source array, forwarded to the kernel
  // output : __half destination array, forwarded to the kernel
  // size   : number of elements to convert; values <= 0 make this a no-op
  // stream : CUDA stream the kernel is launched on
  void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
  {
      // Nothing to do for an empty range. This also stops a negative size from
      // producing a negative block count, which dim3 would convert to a huge
      // unsigned value (an invalid launch configuration).
      if (size <= 0) { return; }

      const int threads = 1024;

      // Round up to whole blocks. Written as (size - 1) / threads + 1 rather
      // than (size + threads - 1) / threads so the sum cannot overflow int.
      const dim3 grid_dim((size - 1) / threads + 1);
      const dim3 block_dim(threads);

      param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
  }
  // Bit-for-bit copy of 32-bit words from `input` to `output`.
  //
  // Thread `idx` loads input[idx] as a float, reinterprets those 32 bits as a
  // __half2 (no numeric conversion takes place) and stores them into `output`
  // viewed as an array of __half2. Threads with idx >= size do nothing.
  //
  // input  : source array, read as 32-bit words at indices [0, size)
  // output : destination, written as __half2 at indices [0, size)
  // size   : number of 32-bit words (__half2 elements) to copy
  //
  // NOTE(review): this is only meaningful if `input` really holds packed
  // half-precision pairs behind its float pointer type - confirm with callers.
  // NOTE(review): `output` is accessed through a __half2*, so it presumably
  // has to be aligned for __half2 - confirm.
  __global__ void param_update_kernel_half(const float* input, __half* output, int size)
  {
      const int idx = blockIdx.x * blockDim.x + threadIdx.x;

      // Threads past the end of the data have no work.
      if (idx >= size) { return; }

      // Load one 32-bit word and store the same bits as a __half2.
      const float word = input[idx];
      reinterpret_cast<__half2*>(output)[idx] = *reinterpret_cast<const __half2*>(&word);
  }
  // Enqueues param_update_kernel_half on `stream`, one thread per __half2.
  //
  // `size` is halved (integer division) before the launch, so the kernel is
  // asked to process size / 2 elements of two halves each. The launch uses
  // 1024 threads per block and enough blocks to cover that count. This
  // function only issues the launch; it performs no stream synchronization
  // and no error query.
  //
  // input  : source array, forwarded to the kernel
  // output : __half destination array, forwarded to the kernel
  // size   : element count before halving; values < 2 make this a no-op
  // stream : CUDA stream the kernel is launched on
  //
  // NOTE(review): because of the integer division, an odd `size` leaves the
  // final element untouched. Confirm callers always pass an even count, or
  // add explicit tail handling once the input buffer layout is confirmed.
  void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream)
  {
      const int threads = 1024;

      // From here on `size` counts __half2 pairs.
      size /= 2;

      // No complete pair to copy. This also stops a negative size from
      // producing a negative block count, which dim3 would convert to a huge
      // unsigned value (an invalid launch configuration).
      if (size <= 0) { return; }

      // Round up to whole blocks.
      const dim3 grid_dim((size - 1) / threads + 1);
      const dim3 block_dim(threads);

      param_update_kernel_half<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
  }