token_sort.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cassert>

#include "custom_cuda_layers.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

namespace td_sort {

constexpr int threads = 512;
constexpr int granularity = 16;
constexpr int mem_vals = granularity / sizeof(int32_t);
constexpr int max_buffer_size = (threads + 1) * mem_vals;

#ifdef __HIP_PLATFORM_HCC__
constexpr int warp_size = 64;
#else
constexpr int warp_size = 32;
#endif

constexpr int max_warps = threads / warp_size;

}  // namespace td_sort
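
// scan_sort sorts the token ids for one (layer, batch) slice of `data` into ascending order.
// Each id is flagged with a 1 in a shared-memory table, a block-wide prefix scan over the
// flags gives every id its rank, and the ids are scattered back by rank. Because presence is
// recorded as a 0/1 flag rather than a count, the ids within a slice are expected to be unique.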
template <int VALS_PER_THREAD>
__global__ void scan_sort(int32_t* data, int reserved_tokens, int original_tokens)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<td_sort::warp_size> warp = cg::tiled_partition<td_sort::warp_size>(tb);

    __shared__ int32_t indices_buffer[td_sort::max_buffer_size];
    __shared__ int32_t intermediate_buffer[td_sort::max_warps];
    __shared__ int32_t sorted_indices_buffer[td_sort::max_buffer_size];

    for (int i = tb.thread_index().x * td_sort::mem_vals; i < original_tokens + 1;
         i += tb.group_dim().x * td_sort::mem_vals) {
        uint32_t zeros[td_sort::mem_vals] = {0, 0, 0, 0};
        mem_access::store_shared<td_sort::granularity>(indices_buffer + i, zeros);
    }
    int32_t local_vals[VALS_PER_THREAD];

    // We flatten layers/batch into a single indexing dimension
    int32_t* data_block = data + tb.group_index().x * reserved_tokens;

    // The next two loops really could be fused for a more logical code layout, but we don't
    // want to move the barrier forward
#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            mem_access::load_global<sizeof(int32_t)>(local_vals + i, data_block + iter_idx);
        } else {
            local_vals[i] = 0;
        }
    }

    tb.sync();

#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            const int32_t one = 1;
            mem_access::store_shared<sizeof(int32_t)>(indices_buffer + local_vals[i], &one);
        }
    }

    tb.sync();
    int32_t local_input[td_sort::mem_vals];
    mem_access::load_shared<td_sort::granularity>(
        local_input, indices_buffer + tb.thread_index().x * td_sort::mem_vals);

    int32_t reduce_vals[td_sort::mem_vals];
    reduce_vals[0] = local_input[0];
#pragma unroll
    for (int i = 1; i < td_sort::mem_vals; i++) {
        reduce_vals[i] = local_input[i] + reduce_vals[i - 1];
    }

    int32_t step_1_val = reduce_vals[td_sort::mem_vals - 1];

    // Short span exclusive scan algorithm (less work efficient)
#pragma unroll
    for (int i = 1; i < td_sort::warp_size; i *= 2) {
        int32_t step_val = warp.shfl_up(step_1_val, i);
        step_1_val = (warp.thread_rank() < i) ? step_1_val : step_1_val + step_val;
    }

    if (warp.thread_rank() == td_sort::warp_size - 1) {
        mem_access::store_shared<sizeof(int32_t)>(intermediate_buffer + warp.meta_group_rank(),
                                                  &step_1_val);
    }

    tb.sync();
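
    // Second scan step: the first warp converts the per-warp totals staged in
    // intermediate_buffer into an inclusive prefix sum over warps, so each warp can later
    // read off the combined total of all warps that precede it.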
    if (warp.meta_group_rank() == 0) {
        int32_t step_2_val = 0;
        if (warp.thread_rank() < td_sort::max_warps) {
            mem_access::load_shared<sizeof(int32_t)>(&step_2_val,
                                                     intermediate_buffer + warp.thread_rank());
        }

#pragma unroll
        for (int i = 1; i < td_sort::warp_size; i *= 2) {
            int32_t step_val = warp.shfl_up(step_2_val, i);
            step_2_val = (warp.thread_rank() < i) ? step_2_val : step_2_val + step_val;
        }

        if (warp.thread_rank() < td_sort::max_warps) {
            mem_access::store_shared<sizeof(int32_t)>(intermediate_buffer + warp.thread_rank(),
                                                      &step_2_val);
        }
    }

    tb.sync();

    int step_2_val = 0;
    if (warp.meta_group_rank() > 0) {
        mem_access::load_shared<sizeof(int32_t)>(
            &step_2_val, intermediate_buffer + warp.meta_group_rank() - 1);
    }
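
    // Fold the three levels together: subtracting thread_offset turns step_1_val into this
    // thread's exclusive offset within its warp, and step_2_val adds the totals of all
    // preceding warps, so reduce_vals[i] becomes the block-wide inclusive prefix sum of the
    // presence flags.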
    const int thread_offset = reduce_vals[td_sort::mem_vals - 1];
#pragma unroll
    for (int i = 0; i < td_sort::mem_vals; i++) {
        reduce_vals[i] += step_1_val + step_2_val - thread_offset;
    }

    mem_access::store_shared<td_sort::granularity>(
        indices_buffer + tb.thread_index().x * td_sort::mem_vals, reduce_vals);

    if (tb.thread_index().x == 0) {
        indices_buffer[original_tokens] = original_tokens - indices_buffer[original_tokens];
    }

    tb.sync();
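
    // Scatter: for a nonzero id v, indices_buffer[v - 1] is the number of flagged ids smaller
    // than v, i.e. v's position in ascending order; 0 (also used as the out-of-range fill
    // above) maps straight to slot 0.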
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            if (local_vals[i] == 0) {
                int zero = 0;
                mem_access::store_shared<sizeof(int32_t)>(sorted_indices_buffer, &zero);
            } else {
                int sorted_idx;
                mem_access::load_shared<sizeof(int32_t)>(&sorted_idx,
                                                         indices_buffer + local_vals[i] - 1);
                mem_access::store_shared<sizeof(int32_t)>(sorted_indices_buffer + sorted_idx,
                                                          local_vals + i);
            }
        }
    }

    tb.sync();

#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            int32_t store_val;
            mem_access::load_shared<sizeof(int32_t)>(&store_val, sorted_indices_buffer + iter_idx);
            mem_access::store_global<sizeof(int32_t)>(data_block + iter_idx, &store_val);
        }
    }
}
void launch_token_sort(int32_t* indices,
                       int layers,
                       int batch_size,
                       int reserved_size,
                       int original_tokens,
                       cudaStream_t stream)
{
    // Each sort is completely independent, can flatten this dimension
    dim3 grid(layers * batch_size);
    dim3 block(td_sort::threads);

    const int vals_per_thread = (reserved_size + td_sort::threads - 1) / td_sort::threads;

    if (vals_per_thread == 1) {
        scan_sort<1><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 2) {
        scan_sort<2><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 3) {
        scan_sort<3><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 4) {
        scan_sort<4><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else {
        assert(false);
    }
}
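
// Illustrative sketch (not part of the original file): one way a host could drive
// launch_token_sort. The helper name `example_sort_driver` and the fill step are assumptions
// for demonstration only; the real callers live in the DeepSpeed bindings.
void example_sort_driver(int layers, int batch_size, int reserved_size, int original_tokens)
{
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // One contiguous block of token ids per (layer, batch) pair, reserved_size entries apiece.
    int32_t* indices;
    cudaMalloc(&indices, static_cast<size_t>(layers) * batch_size * reserved_size * sizeof(int32_t));

    // ... fill `indices` on the device with the token ids to keep for each slice ...

    // Sort every slice in place, ascending.
    launch_token_sort(indices, layers, batch_size, reserved_size, original_tokens, stream);
    cudaStreamSynchronize(stream);

    cudaFree(indices);
    cudaStreamDestroy(stream);
}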