// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cassert>
#include <cstdio>
#include "dequantization_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization_utils.h"
#include "reduction_utils.h"

using rop = reduce::ROpType;

/*
TODO(cmikeh2): Add an implementation that better handles larger nodes. It would likely make
sense to leverage some parallel reductions here to improve performance.
*/
template <int numBits, int numTensors, int totalChunks, quantize::Type quantType>
__global__ void __launch_bounds__(1024) dequant_reduce(int8_t* reduced_data,
                                                        float* reduced_scales,
                                                        const int8_t* input_data,
                                                        const float* input_scales,
                                                        int elems_per_out_group,
                                                        int elems_per_in_tensor,
                                                        int groups_per_in_tensor,
                                                        int elems_per_in_group,
                                                        int num_tensors)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // NOTE(cmikeh2): This probably could be hardcoded to a larger number,
    // but that means even stronger restrictions on the number of elements per group.
    // A performance analysis here might be beneficial.
    constexpr int mem_granularity = (numBits == 8) ? 8 : 4;
    constexpr int elems_per_load = mem_granularity / sizeof(int8_t);  // div by 1
    constexpr int storage_values = 16 / sizeof(__half2);
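    // With these constants, each thread dequantizes 8 half-precision values per chunk:
    // for numBits == 8 that is an 8-byte load of eight int8 values, and for numBits == 4
    // it is a 4-byte load of eight 2-way-packed int4 values; both expand into
    // storage_values (= 4) __half2 registers.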
    const int block_offset = tb.group_index().x * elems_per_out_group;
    const int elem_offset = tb.thread_index().x * elems_per_load;
    const int base_offset = block_offset + elem_offset;
    const int stride = tb.group_dim().x * elems_per_load;

    __half2 local_buffer[totalChunks * storage_values];

    quantize::GroupStats<quantType> stats;

#pragma unroll
    for (int i = 0; i < totalChunks; i++) {
        __half2* iteration_buffer = local_buffer + i * storage_values;

#pragma unroll
        for (int j = 0; j < storage_values; j++) {
            iteration_buffer[j] = reduce::init<rop::Add, __half2>();
        }

        const int iter_offset = i * stride + base_offset;
        const int iter_scale_idx = iter_offset / elems_per_in_group;
        bool do_loads = i * stride + elem_offset < elems_per_out_group;
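        // When the tensor count is a compile-time constant (numTensors > 0), the loop over
        // input tensors below is fully unrolled; otherwise the kernel falls back to the
        // runtime num_tensors argument with a partial unroll.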
        if (numTensors > 0) {
#pragma unroll
            for (int j = 0; j < numTensors; j++) {
                if (do_loads) {
                    int8_t load_buffer[elems_per_load];

                    mem_access::load_global<mem_granularity>(
                        load_buffer, input_data + j * elems_per_in_tensor + iter_offset);

                    quantize::Params<quantType, numBits> params(
                        input_scales + j * groups_per_in_tensor, iter_scale_idx);

                    __half2 dequant_buffer[storage_values];
                    dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);

#pragma unroll
                    for (int k = 0; k < storage_values; k++) {
                        iteration_buffer[k] =
                            reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
                    }
                }
            }
        } else {
#pragma unroll 4
            for (int j = 0; j < num_tensors; j++) {
                if (do_loads) {
                    int8_t load_buffer[elems_per_load];

                    mem_access::load_global<mem_granularity>(
                        load_buffer, input_data + j * elems_per_in_tensor + iter_offset);

                    quantize::Params<quantType, numBits> params(
                        input_scales + j * groups_per_in_tensor, iter_scale_idx);

                    __half2 dequant_buffer[storage_values];
                    dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);

#pragma unroll
                    for (int k = 0; k < storage_values; k++) {
                        iteration_buffer[k] =
                            reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
                    }
                }
            }
        }

#pragma unroll
        for (int j = 0; j < storage_values; j++) { stats.update(iteration_buffer[j]); }
    }

    auto params = stats.template get_params<numBits, 1024>(tb, warp);

    if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); }

#pragma unroll
    for (int i = 0; i < totalChunks; i++) {
        const int iter_offset = i * stride + base_offset;
        if (i * stride + elem_offset < elems_per_out_group) {
            int8_t local_output[elems_per_load];
            quantize::_chunk<numBits, quantType>(
                local_output, local_buffer + i * storage_values, params);
            mem_access::store_global<mem_granularity>(reduced_data + iter_offset, local_output);
        }
    }
}
template <int Power>
int32_t pow2_round(int32_t raw_value)
{
    return (((raw_value - 1) >> Power) + 1) << Power;
}
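// pow2_round<Power> rounds raw_value up to the next multiple of 2^Power, e.g.
// pow2_round<1>(5) == 6, pow2_round<1>(7) == 8, and pow2_round<1>(8) == 8.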
#define LAUNCH_DEQUANT_REDUCE(num_chunks)                      \
    dequant_reduce<numBits, numTensors, num_chunks, quantType> \
        <<<grid, block, 0, stream>>>(reduced_data,             \
                                     reduced_scales,           \
                                     input_data,               \
                                     input_scales,             \
                                     elems_per_out_group,      \
                                     elems_per_in_tensor,      \
                                     groups_per_in_tensor,     \
                                     elems_per_in_group,       \
                                     num_tensors);
template <int numBits, int numTensors, quantize::Type quantType>
void launch_dequant_reduce_impl(int8_t* reduced_data,
                                float* reduced_scales,
                                const int8_t* input_data,
                                const float* input_scales,
                                int out_groups,
                                int elems_per_out_group,
                                int elems_per_in_tensor,
                                int groups_per_in_tensor,
                                int elems_per_in_group,
                                int num_tensors,
                                cudaStream_t stream)
{
    // This is a coincidence. This is derived from 8 halves per 16 bytes with 2-way packing
    // for int4.
    constexpr int elems_per_thread = numBits;
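    // That is: each thread handles 8 half-precision values (16 bytes of fp16) per step, which
    // occupy 8 int8_t slots when numBits == 8 and 4 slots when numBits == 4, so the per-thread
    // footprint in quantized elements happens to equal numBits.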
    const int one_step_threads =
        next_pow2((elems_per_out_group + elems_per_thread - 1) / (elems_per_thread));
    // TODO(cmikeh2): Tune this
    const int threads = (one_step_threads < 1024) ? one_step_threads : 1024;

    dim3 block(threads);
    dim3 grid(out_groups);

    const int elems_per_step = threads * elems_per_thread;
    const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step;

    const int unroll = (unroll_raw >= 4) ? pow2_round<1>(unroll_raw) : unroll_raw;
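    // Worked example (numBits == 8, elems_per_out_group == 16384): one_step_threads covers
    // 16384 / 8 = 2048 threads, so threads caps at 1024, elems_per_step = 8192,
    // unroll_raw = 2, and unroll = 2 selects the totalChunks = 2 instantiation below. For
    // unroll_raw >= 4, pow2_round<1> rounds up to an even count, which is why only the even
    // instantiations 4, 6, 8, 10, and 12 appear past that point.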
    if (unroll == 1) {
        // 0-4096 elems
        LAUNCH_DEQUANT_REDUCE(1);
    } else if (unroll == 2) {
        // 4097-8192 etc...
        LAUNCH_DEQUANT_REDUCE(2);
    } else if (unroll == 3) {
        LAUNCH_DEQUANT_REDUCE(3);
    } else if (unroll == 4) {
        LAUNCH_DEQUANT_REDUCE(4);
    } else if (unroll == 6) {
        LAUNCH_DEQUANT_REDUCE(6);
    } else if (unroll == 8) {
        LAUNCH_DEQUANT_REDUCE(8);
    } else if (unroll == 10) {
        LAUNCH_DEQUANT_REDUCE(10);
    } else if (unroll == 12) {
        // 48k limit
        LAUNCH_DEQUANT_REDUCE(12);
    } else {
        assert(false);
    }
}
#define LAUNCH_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE)           \
    launch_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data, \
                                                               reduced_scales, \
                                                               input_data,     \
                                                               input_scales,   \
                                                               out_groups,     \
                                                               elems_per_out_group,  \
                                                               elems_per_in_tensor,  \
                                                               groups_per_in_tensor, \
                                                               elems_per_in_group,   \
                                                               num_gpus,       \
                                                               stream);
void launch_dequant_reduce(int8_t* reduced_data,
                           float* reduced_scales,
                           const int8_t* input_data,
                           const float* input_scales,
                           int num_gpus,
                           int num_bits,
                           quantize::Type quant_type,
                           int out_groups,
                           int elems_per_out_group,
                           int elems_per_in_tensor,
                           int groups_per_in_tensor,
                           int elems_per_in_group,
                           cudaStream_t stream)
{
    if (quant_type == quantize::Type::Symmetric) {
        if (num_bits == 4) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Symmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Symmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Symmetric);
            }
        } else if (num_bits == 8) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Symmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Symmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Symmetric);
            }
        }
    } else if (quant_type == quantize::Type::Asymmetric) {
        if (num_bits == 4) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Asymmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Asymmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Asymmetric);
            }
        } else if (num_bits == 8) {
            if (num_gpus == 8) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Asymmetric);
            } else if (num_gpus == 16) {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Asymmetric);
            } else {
                LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Asymmetric);
            }
        }
    }
}
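/*
Illustrative host-side call (hypothetical sizes, shown only to document the argument layout):
reducing int8, symmetrically quantized shards from 8 ranks, where each rank contributes
4096 groups of 2048 quantized elements:

    launch_dequant_reduce(reduced_data, reduced_scales, input_data, input_scales,
                          8,                          // num_gpus: compile-time specialized path
                          8,                          // num_bits
                          quantize::Type::Symmetric,  // quant_type
                          4096,                       // out_groups
                          2048,                       // elems_per_out_group
                          4096 * 2048,                // elems_per_in_tensor
                          4096,                       // groups_per_in_tensor
                          2048,                       // elems_per_in_group
                          stream);

Any num_gpus other than 8 or 16 selects the -1 specialization, which loops over the runtime
num_tensors argument inside the kernel instead of a compile-time-unrolled tensor loop.
*/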