quantize.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
#include "reduction_utils.h"

namespace cg = cooperative_groups;

/*
Pure quantization kernel with no fusion.
*/
template <int q_bits,
          quantize::Type quant_type,
          int UNROLL,
          int internal_unroll,
          int threads_per_group,
          int max_threads>
__global__ void cached_quantization(int8_t* __restrict__ output_data,
                                    float* __restrict__ params,
                                    const __half* __restrict__ input_data,
                                    int groups,
                                    int elems_per_group)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // Indexing offsets
    const int block_offset =
        (tb.group_index().x * (max_threads / threads_per_group) * elems_per_group) +
        (tb.thread_index().y * elems_per_group);
    const int elem_offset = tb.thread_index().x * quantize::h_per_load;
    const int base_offset = block_offset + elem_offset;
    const int stride = tb.size() * quantize::h_per_load;
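    // Each block serves (max_threads / threads_per_group) quantization groups:
    // thread_index().y selects the group within the block, while thread_index().x
    // strides across that group's halves in vectorized chunks of h_per_load.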

    const __half* input_base = input_data + base_offset;

    __half2 local_buffer[UNROLL * internal_unroll * quantize::h2_per_load];
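    // local_buffer is statically sized so it can live entirely in registers;
    // UNROLL and internal_unroll trade register pressure for fewer load passes.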

#pragma unroll
    for (int i = 0; i < UNROLL; i++) {
        // Convenience helper, should resolve to register indices and not realize.
        __half2* iteration_buffer = local_buffer + i * internal_unroll * quantize::h2_per_load;
#pragma unroll
        for (int j = 0; j < internal_unroll; j++) {
            const int iteration = i * internal_unroll + j;
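            // Predicated load: lanes whose offset lands past elems_per_group skip
            // the global read, so the tail of a group needs no separate bounds path.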
            mem_access::load_global<quantize::granularity>(
                iteration_buffer + j * quantize::h2_per_load,
                input_base + iteration * stride,
                elem_offset + iteration * stride < elems_per_group);
        }
    }
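
    // With the group's values resident in registers, quantize::local_array reduces
    // across the group to derive the quantization parameters, writes them to
    // params, and packs the quantized values into output_data.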
    quantize::
        local_array<quant_type, q_bits, UNROLL * internal_unroll, threads_per_group, max_threads>(
            local_buffer, params, output_data, elems_per_group, groups);
}

/********* Launcher methods ***********/
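
// The two-level macro below maps the runtime (q_bits, quant_type) pair onto the
// four compile-time template instantiations of cached_quantization.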
#define LAUNCH_CACHED_QUANT_CALL(q_bits, quant_type) \
    cached_quantization<q_bits,                      \
                        quant_type,                  \
                        unroll_factor,               \
                        internal_unroll_l,           \
                        threads_per_group,           \
                        max_threads>                 \
        <<<grid, block, 0, stream>>>(output_data, params, input_data, groups, elems_per_group);

#define LAUNCH_CACHED_QUANT(                                                        \
    q_bits, quant_type, unroll_factor_in, internal_unroll_in, threads_per_group_in) \
    const int unroll_factor = unroll_factor_in;                                     \
    const int internal_unroll_l = internal_unroll_in;                               \
    const int threads_per_group = threads_per_group_in;                             \
    if (q_bits == 4) {                                                              \
        if (quant_type == quantize::Type::Asymmetric) {                             \
            LAUNCH_CACHED_QUANT_CALL(4, quantize::Type::Asymmetric)                 \
        } else {                                                                    \
            LAUNCH_CACHED_QUANT_CALL(4, quantize::Type::Symmetric)                  \
        }                                                                           \
    } else {                                                                        \
        if (quant_type == quantize::Type::Asymmetric) {                             \
            LAUNCH_CACHED_QUANT_CALL(8, quantize::Type::Asymmetric)                 \
        } else {                                                                    \
            LAUNCH_CACHED_QUANT_CALL(8, quantize::Type::Symmetric)                  \
        }                                                                           \
    }

void launch_quant(int8_t* output_data,
                  float* params,
                  const __half* input_data,
                  const int groups,
                  const int elems_per_group,
                  const int num_bits,
                  const quantize::Type quant_type,
                  cudaStream_t stream)
{
    constexpr int max_threads = 256;
    constexpr int internal_unroll = 2;

    const bool is_subblock_schedule = (elems_per_group <= 128);
    const int h_per_step = is_subblock_schedule ? quantize::h_per_load
                                                : quantize::h_per_load * internal_unroll;

    // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
    // warp-sized blocks rather than stepping up to 64/96 threads
    const int one_step_threads = next_pow2((elems_per_group + h_per_step - 1) / h_per_step);
    const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;

    const int groups_per_block =
        is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
    const int groups_launch = (groups_per_block + groups - 1) / groups_per_block;

    dim3 block(threads_per_group, groups_per_block);
    dim3 grid(groups_launch);

    const int elems_per_step = threads_per_group * h_per_step;
    const int external_unroll = (elems_per_group + elems_per_step - 1) / elems_per_step;
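    // Worked example (assuming quantize::h_per_load == 8, i.e. 16B vector loads):
    // for elems_per_group == 4096, h_per_step = 8 * 2 = 16, one_step_threads =
    // next_pow2(4096 / 16) = 256, so threads_per_group = 256, elems_per_step = 4096,
    // and external_unroll = 1, landing in the single-unroll branch below.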

    if (is_subblock_schedule) {
        // <=128
        if (threads_per_group == 1) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 1);
        } else if (threads_per_group == 2) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 2);
        } else if (threads_per_group == 4) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 4);
        } else if (threads_per_group == 8) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 8);
        } else if (threads_per_group == 16) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 16);
        }
    } else if (external_unroll == 1) {
        // 129 - 4096 elems
        // (this can launch with 1-7 warps as well)
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, internal_unroll, max_threads);
    } else if (external_unroll == 2) {
        // 4097 - 8192 elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 2, internal_unroll, max_threads);
    } else if (external_unroll == 3) {
        // 8193 - 12288 elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 3, internal_unroll, max_threads);
    } else if (external_unroll == 4) {
        // 12289 - 16384 elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 4, internal_unroll, max_threads);
    }
}
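
/*
Illustrative host-side usage (a sketch, not part of the original file; the buffer
names and the per-group params layout are assumptions: commonly one scale per
group for symmetric quantization and a scale plus offset for asymmetric):

    int groups = 1024, elems_per_group = 4096;
    int8_t* q_out;        // device buffer, groups * elems_per_group bytes for 8-bit
    float* q_params;      // device buffer, one (or two) floats per group
    const __half* input;  // device buffer, groups * elems_per_group halves
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    launch_quant(q_out, q_params, input, groups, elems_per_group,
                 8, quantize::Type::Symmetric, stream);
*/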