quantize.cu
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
#include "reduction_utils.h"

namespace cg = cooperative_groups;
/*
Pure quantization kernel with no fusion.
*/
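// Each thread block quantizes one or more groups of __half inputs: every thread stages its
// slice of a group in registers via vectorized loads (quantize::h_per_load halves per access),
// then quantize::local_array reduces the group statistics, writes the scale (and, for
// asymmetric types, an offset) to `params`, and emits the quantized values to `output_data`.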
template <int q_bits,
          quantize::Type quant_type,
          int UNROLL,
          int internal_unroll,
          int threads_per_group,
          int max_threads>
__global__ void cached_quantization(int8_t* __restrict__ output_data,
                                    float* __restrict__ params,
                                    const __half* __restrict__ input_data,
                                    int groups,
                                    int elems_per_group)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // Indexing offsets
    const int block_offset =
        (tb.group_index().x * (max_threads / threads_per_group) * elems_per_group) +
        (tb.thread_index().y * elems_per_group);
    const int elem_offset = tb.thread_index().x * quantize::h_per_load;
    const int base_offset = block_offset + elem_offset;
    const int stride = tb.size() * quantize::h_per_load;
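    // block_offset selects this thread's group (a block may pack several small groups along
    // threadIdx.y), elem_offset is the thread's first element within that group, and stride is
    // how far the whole block advances between load iterations.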
    const __half* input_base = input_data + base_offset;

    __half2 local_buffer[UNROLL * internal_unroll * quantize::h2_per_load];

#pragma unroll
    for (int i = 0; i < UNROLL; i++) {
        // Convenience helper, should resolve to register indices and not realize.
        __half2* iteration_buffer = local_buffer + i * internal_unroll * quantize::h2_per_load;
#pragma unroll
        for (int j = 0; j < internal_unroll; j++) {
            const int iteration = i * internal_unroll + j;
            mem_access::load_global<quantize::granularity>(
                iteration_buffer + j * quantize::h2_per_load,
                input_base + iteration * stride,
                elem_offset + iteration * stride < elems_per_group);
        }
    }
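    // The final argument to load_global predicates the access, so threads whose iteration falls
    // past the end of the group never read out of bounds before the register buffer is handed
    // to the quantization helper below.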
    quantize::
        local_array<quant_type, q_bits, UNROLL * internal_unroll, threads_per_group, max_threads>(
            local_buffer, params, output_data, elems_per_group, groups);
}
/********* Launcher methods ***********/

#define LAUNCH_CACHED_QUANT(                                                            \
    q_bits, quant_type, unroll_factor, internal_unroll, threads_per_group, max_threads) \
    cached_quantization<q_bits,                                                         \
                        quant_type,                                                     \
                        unroll_factor,                                                  \
                        internal_unroll,                                                \
                        threads_per_group,                                              \
                        max_threads>                                                    \
        <<<grid, block, 0, stream>>>(output_data, params, input_data, groups, elems_per_group);
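// The macro above expands to a kernel launch for a single compile-time configuration and relies
// on the `grid`, `block`, and `stream` variables in scope at the expansion site.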
template <int numBits, quantize::Type qType>
void launch_quant(int8_t* output_data,
                  float* params,
                  const __half* input_data,
                  const int groups,
                  const int elems_per_group,
                  cudaStream_t stream)
{
    constexpr int max_threads = 256;
    constexpr int internal_unroll = 2;
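    // Two scheduling regimes: small groups (<= 128 halves) are packed several to a block so a
    // full 256-thread block stays busy, while larger groups get a whole block each and rely on
    // unrolling to cover the group.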
    const bool is_subblock_schedule = (elems_per_group <= 128) ? true : false;
    const int h_per_step = is_subblock_schedule ? quantize::h_per_load
                                                : quantize::h_per_load * internal_unroll;

    // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
    // warp-sized blocks rather than stepping up to 64/96 threads
    const int one_step_threads = next_pow2((elems_per_group + h_per_step - 1) / h_per_step);
    const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;

    const int groups_per_block =
        is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
    const int groups_launch = (groups_per_block + groups - 1) / groups_per_block;

    dim3 block(threads_per_group, groups_per_block);
    dim3 grid(groups_launch);

    const int elems_per_step = threads_per_group * h_per_step;
    const int external_unroll = (elems_per_group + elems_per_step - 1) / elems_per_step;
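    // external_unroll counts how many block-wide steps are needed to cover one group; each step
    // covers threads_per_group * h_per_step halves, which is what the element ranges in the
    // comments below reflect.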
    if (is_subblock_schedule) {
        // <=128
        if (threads_per_group == 1) {
            LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 1, max_threads);
        } else if (threads_per_group == 2) {
            LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 2, max_threads);
        } else if (threads_per_group == 4) {
            LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 4, max_threads);
        } else if (threads_per_group == 8) {
            LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 8, max_threads);
        } else if (threads_per_group == 16) {
            LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 16, max_threads);
        }
    } else if (external_unroll == 1) {
        // 129 - 4096 elems
        // (this can launch with 1-7 warps as well)
        LAUNCH_CACHED_QUANT(numBits, qType, 1, internal_unroll, max_threads, max_threads);
    } else if (external_unroll == 2) {
        // 4097 - 8192 elems
        LAUNCH_CACHED_QUANT(numBits, qType, 2, internal_unroll, max_threads, max_threads);
    } else if (external_unroll == 3) {
        // 8193 - 12288 elems
        LAUNCH_CACHED_QUANT(numBits, qType, 3, internal_unroll, max_threads, max_threads);
    } else if (external_unroll == 4) {
        // 12289 - 16384 elems
        LAUNCH_CACHED_QUANT(numBits, qType, 4, internal_unroll, max_threads, max_threads);
    }
}
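/*
Illustrative call site (not part of this file; buffer and size names are hypothetical):

    // `out`, `scales`, and `in` are device pointers sized for `num_groups` groups of
    // `group_size` halves each.
    launch_quant<8, quantize::Type::Symmetric>(out, scales, in, num_groups, group_size, stream);

The explicit instantiations below are what make these <bits, type> combinations available to
other translation units.
*/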
template void launch_quant<8, quantize::Type::Symmetric>(int8_t*,
                                                          float*,
                                                          const __half*,
                                                          int,
                                                          int,
                                                          cudaStream_t);

template void launch_quant<8, quantize::Type::Asymmetric>(int8_t*,
                                                           float*,
                                                           const __half*,
                                                           int,
                                                           int,
                                                           cudaStream_t);

template void launch_quant<8, quantize::Type::IntegerSymmetric>(int8_t*,
                                                                 float*,
                                                                 const __half*,
                                                                 int,
                                                                 int,
                                                                 cudaStream_t);

template void launch_quant<4, quantize::Type::Symmetric>(int8_t*,
                                                          float*,
                                                          const __half*,
                                                          int,
                                                          int,
                                                          cudaStream_t);

template void launch_quant<4, quantize::Type::Asymmetric>(int8_t*,
                                                           float*,
                                                           const __half*,
                                                           int,
                                                           int,
                                                           cudaStream_t);

template void launch_quant<4, quantize::Type::IntegerSymmetric>(int8_t*,
                                                                 float*,
                                                                 const __half*,
                                                                 int,
                                                                 int,
                                                                 cudaStream_t);