gather_scatter.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "custom_cuda_layers.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

namespace td_data {
// Width of each vectorized memory transaction, in bytes.
constexpr int granularity = 16;
}  // namespace td_data
template <typename T>
__global__ void gather_tokens_impl(T* retained_tokens,
                                   const T* activations,
                                   int32_t* gather_indices,
                                   int32_t sampled_tokens,
                                   int32_t channels,
                                   int32_t read_batch_stride,
                                   int32_t read_seq_stride,
                                   int32_t write_batch_stride,
                                   int32_t write_seq_stride)
{
    // Number of elements of T moved per vectorized load/store.
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    cg::thread_block tb = cg::this_thread_block();

    // Each block copies one token: blockIdx.x selects the batch, blockIdx.y selects the
    // position in the compacted (retained) sequence, and the index buffer maps that
    // position back to the token's location in the full sequence.
    const int gather_idx = gather_indices[tb.group_index().x * sampled_tokens + tb.group_index().y];
    const int read_offset = read_batch_stride * tb.group_index().x + read_seq_stride * gather_idx;
    const int write_offset =
        write_batch_stride * tb.group_index().x + write_seq_stride * tb.group_index().y;

    // Threads stride across the channel dimension, moving `granularity` bytes per access.
    // Assumes channels is a multiple of mem_vals_t so each vectorized access stays in bounds.
    for (int i = tb.thread_index().x * mem_vals_t; i < channels; i += blockDim.x * mem_vals_t) {
        T local_data[mem_vals_t];
        mem_access::load_global<td_data::granularity>(local_data, activations + read_offset + i);
        mem_access::store_global<td_data::granularity>(retained_tokens + write_offset + i,
                                                       local_data);
    }
}
template <typename T>
void launch_gather_tokens(T* retained_tokens,
                          T* activations,
                          int32_t* gather_indices,
                          int32_t batch_size,
                          int32_t sampled_tokens,
                          int32_t channels,
                          int32_t read_batch_stride,
                          int32_t read_seq_stride,
                          int32_t write_batch_stride,
                          int32_t write_seq_stride,
                          cudaStream_t stream)
{
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    // One thread per vectorized load across the channel dimension, capped at the
    // 1024-thread block limit.
    const int load_steps = (channels + mem_vals_t - 1) / mem_vals_t;
    const int threads = (load_steps >= 1024) ? 1024 : load_steps;
    dim3 block(threads);
    // One block per retained token in each batch element.
    dim3 grid(batch_size, sampled_tokens);

    gather_tokens_impl<T><<<grid, block, 0, stream>>>(retained_tokens,
                                                      activations,
                                                      gather_indices,
                                                      sampled_tokens,
                                                      channels,
                                                      read_batch_stride,
                                                      read_seq_stride,
                                                      write_batch_stride,
                                                      write_seq_stride);
}
template void launch_gather_tokens<float>(float*,
                                          float*,
                                          int32_t*,
                                          int32_t,
                                          int32_t,
                                          int32_t,
                                          int32_t,
                                          int32_t,
                                          int32_t,
                                          int32_t,
                                          cudaStream_t);

template void launch_gather_tokens<__half>(__half*,
                                           __half*,
                                           int32_t*,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           cudaStream_t);
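
/*
Illustrative usage (not part of the original file): a rough sketch of how the gather
launcher might be invoked to compact `sampled_tokens` tokens per sequence out of a
contiguous [batch, seq_len, channels] activation tensor. The `example_gather` name and
the stride values below are assumptions for illustration, not an API defined here.

    void example_gather(float* retained,         // [batch, sampled_tokens, channels]
                        float* activations,      // [batch, seq_len, channels]
                        int32_t* indices,        // [batch, sampled_tokens], token ids to keep
                        int batch, int seq_len, int sampled_tokens, int channels,
                        cudaStream_t stream)
    {
        launch_gather_tokens<float>(retained,
                                    activations,
                                    indices,
                                    batch,
                                    sampled_tokens,
                                    channels,
                                    seq_len * channels,        // read_batch_stride
                                    channels,                  // read_seq_stride
                                    sampled_tokens * channels, // write_batch_stride
                                    channels,                  // write_seq_stride
                                    stream);
    }
*/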
template <typename T>
__global__ void scatter_tokens_impl(T* all_activations,
                                    const T* layer_activations,
                                    int32_t* gather_indices,
                                    int32_t retained_tokens,
                                    int32_t channels,
                                    int32_t read_batch_stride,
                                    int32_t read_seq_stride,
                                    int32_t write_batch_stride,
                                    int32_t write_seq_stride)
{
    // Number of elements of T moved per vectorized load/store.
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    cg::thread_block tb = cg::this_thread_block();

    // Inverse of the gather: blockIdx.y indexes the compacted sequence, and the stored
    // index gives the token's destination position in the full sequence.
    const int gather_idx =
        gather_indices[tb.group_index().x * retained_tokens + tb.group_index().y];
    const int read_offset =
        read_batch_stride * tb.group_index().x + read_seq_stride * tb.group_index().y;
    const int write_offset =
        write_batch_stride * tb.group_index().x + write_seq_stride * gather_idx;

    // Threads stride across the channel dimension, moving `granularity` bytes per access.
    for (int i = tb.thread_index().x * mem_vals_t; i < channels; i += mem_vals_t * blockDim.x) {
        T local_data[mem_vals_t];
        mem_access::load_global<td_data::granularity>(local_data,
                                                      layer_activations + read_offset + i);
        mem_access::store_global<td_data::granularity>(all_activations + write_offset + i,
                                                       local_data);
    }
}
template <typename T>
void launch_scatter_tokens(T* all_activations,
                           T* layer_activations,
                           int32_t* gather_indices,
                           int32_t batch_size,
                           int32_t sampled_tokens,
                           int32_t channels,
                           int32_t read_batch_stride,
                           int32_t read_seq_stride,
                           int32_t write_batch_stride,
                           int32_t write_seq_stride,
                           cudaStream_t stream)
{
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    // One thread per vectorized load across the channel dimension, capped at the
    // 1024-thread block limit.
    const int load_steps = (channels + mem_vals_t - 1) / mem_vals_t;
    const int threads = (load_steps >= 1024) ? 1024 : load_steps;
    dim3 block(threads);
    // One block per retained token in each batch element.
    dim3 grid(batch_size, sampled_tokens);

    scatter_tokens_impl<T><<<grid, block, 0, stream>>>(all_activations,
                                                       layer_activations,
                                                       gather_indices,
                                                       sampled_tokens,
                                                       channels,
                                                       read_batch_stride,
                                                       read_seq_stride,
                                                       write_batch_stride,
                                                       write_seq_stride);
}
template void launch_scatter_tokens<float>(float*,
                                           float*,
                                           int32_t*,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           int32_t,
                                           cudaStream_t);

template void launch_scatter_tokens<__half>(__half*,
                                            __half*,
                                            int32_t*,
                                            int32_t,
                                            int32_t,
                                            int32_t,
                                            int32_t,
                                            int32_t,
                                            int32_t,
                                            int32_t,
                                            cudaStream_t);
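
/*
Illustrative usage (not part of the original file): a rough sketch of the inverse
operation, scattering the compacted per-layer activations back into the full
[batch, seq_len, channels] tensor using the same index buffer. The `example_scatter`
name and the stride values below are assumptions for illustration.

    void example_scatter(float* all_activations,    // [batch, seq_len, channels]
                         float* layer_activations,  // [batch, sampled_tokens, channels]
                         int32_t* indices,          // [batch, sampled_tokens]
                         int batch, int seq_len, int sampled_tokens, int channels,
                         cudaStream_t stream)
    {
        launch_scatter_tokens<float>(all_activations,
                                     layer_activations,
                                     indices,
                                     batch,
                                     sampled_tokens,
                                     channels,
                                     sampled_tokens * channels, // read_batch_stride
                                     channels,                  // read_seq_stride
                                     seq_len * channels,        // write_batch_stride
                                     channels,                  // write_seq_stride
                                     stream);
    }
*/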