custom_cuda_kernel.cu 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #include "custom_cuda_layers.h"
  5. __global__ void param_update_kernel(const float* input, __half* output, int size)
  6. {
  7. int id = blockIdx.x * blockDim.x + threadIdx.x;
  8. if (id < size) { output[id] = (__half)input[id]; }
  9. }
  10. void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
  11. {
  12. int threads = 1024;
  13. dim3 grid_dim((size - 1) / threads + 1);
  14. dim3 block_dim(threads);
  15. param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
  16. }
  17. __global__ void param_update_kernel_half(const float* input, __half* output, int size)
  18. {
  19. int id = blockIdx.x * blockDim.x + threadIdx.x;
  20. __half2* output_cast = reinterpret_cast<__half2*>(output);
  21. if (id < size) {
  22. float input_f = input[id];
  23. __half2* input_h = reinterpret_cast<__half2*>(&input_f);
  24. output_cast[id] = *input_h;
  25. }
  26. }
  27. void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream)
  28. {
  29. int threads = 1024;
  30. size /= 2;
  31. dim3 grid_dim((size - 1) / threads + 1);
  32. dim3 block_dim(threads);
  33. param_update_kernel_half<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
  34. }