dequantization_utils.h

/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#pragma once

#include <assert.h>
#include <cooperative_groups.h>
#include <cuda_fp16.h>

#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "quantization.h"
#include "quantization_utils.h"

namespace cg = cooperative_groups;
namespace dequantize {

using Type = quantize::Type;

template <Type qType, int numBits>
using Params = quantize::Params<qType, numBits>;

constexpr int granularity = quantize::granularity;
using PackedInt4 = quantize::PackedInt4;

constexpr int h_per_chunk = granularity / sizeof(__half);
constexpr int h2_per_chunk = granularity / sizeof(__half2);
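
// Worked example: assuming quantize::granularity is 16 bytes (one vectorized
// 16-byte access), h_per_chunk = 16 / sizeof(__half) = 8 __half values and
// h2_per_chunk = 16 / sizeof(__half2) = 4 __half2 values per chunk.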

/*
Device function that reads quantized data from global memory, dequantizes
it, and stores the result back to global memory.

Template Arguments :
    numBits - Number of bits in each quantized element. int: 4 or 8
    qType - Type of quantization performed. Type::Symmetric or Type::Asymmetric
    unroll - Number of load steps to internally unroll. int
    threads - Number of threads performing the dequantization. int

Function Arguments :
    global_output - __half pointer in global memory
    data - Quantized data in global memory
    global_params - Quantization parameters in global memory
    elems_per_group - Number of elements in each quantization group
    total_elems - Tensor size (note: does not need to be a multiple of elems_per_group)
*/
template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(__half* global_output,
                           const int8_t* data,
                           const float* global_params,
                           const int elems_per_group,
                           const int total_elems);
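
/*
Illustrative usage sketch (not part of the original header): a minimal kernel
wrapper a caller might instantiate to run to_global over a full tensor. The
kernel name is hypothetical; only the to_global call reflects this header's
interface. Each block consumes h_per_chunk * threads * unroll output elements,
so a launch would use ceil(total_elems / (h_per_chunk * threads * unroll))
blocks of `threads` threads each.
*/
template <int numBits, Type qType, int unroll, int threads>
__global__ void dequant_to_global_example(__half* global_output,
                                          const int8_t* data,
                                          const float* global_params,
                                          int elems_per_group,
                                          int total_elems)
{
    to_global<numBits, qType, unroll, threads>(
        global_output, data, global_params, elems_per_group, total_elems);
}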

/*
Device function that dequantizes one chunk of quantized input data into 16
bytes of __half output.

Template Arguments :
    numBits - Number of bits in each quantized element. int: 8 or 4
    qType - Type of quantization performed. Type::Symmetric or Type::Asymmetric

Function Arguments :
    local_output - Local array to store dequantized data. __half* or __half2*
    data - Pointer to quantized input data. int8_t*
    q_params - Parameters for dequantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params);

template <int numBits, Type qType>
DS_D_INLINE void chunk(__half* local_output, const int8_t* data, Params<qType, numBits> q_params);
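
/*
Illustrative sketch (hypothetical helper, not part of the original header):
dequantizing a single 8-bit symmetric chunk whose quantized bytes are already
staged in a local buffer. The helper name and its arguments are assumptions;
the Params construction and chunk call mirror their uses below.
*/
DS_D_INLINE void chunk_example(__half* local_out,
                               const int8_t* staged_quant,  // h_per_chunk bytes for numBits == 8
                               const float* global_params,
                               const int group_index)
{
    // Build the per-group dequantization parameters, then expand one chunk of
    // h_per_chunk __half values into local_out.
    Params<Type::Symmetric, 8> q_params(global_params, group_index);
    chunk<8, Type::Symmetric>(local_out, staged_quant, q_params);
}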

/**************** Implementations ******************/

template <int numBits, Type qType>
DS_D_INLINE void chunk(__half* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
    constexpr int32_t num_elems_packed = 8 / numBits;
    constexpr int32_t iters = h_per_chunk / num_elems_packed;

#pragma unroll
    for (int i = 0; i < iters; i++) {
        if constexpr (num_elems_packed == 1) {
            // 8-bit path: one quantized value per byte.
            local_output[i] = q_params.dequantize(data[i]);
        } else {
            // 4-bit path: two quantized values packed per byte.
            auto accessible_data = *(PackedInt4*)(&data[i]);
            local_output[2 * i] = q_params.dequantize(accessible_data.low);
            local_output[2 * i + 1] = q_params.dequantize(accessible_data.high);
        }
    }
}

template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
    __half* local_output_cast = reinterpret_cast<__half*>(local_output);
    chunk<numBits>(local_output_cast, data, q_params);
}

template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void _to_global(__half* global_output,
                            const int8_t* data,
                            const float* global_params,
                            const int elems_per_group,
                            const int total_elems)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // Load constants
    // TODO(cmikeh2): Refactor into functions?
    constexpr int load_granularity = granularity * numBits / 16;
    constexpr int load_step_stride = load_granularity * threads;
    constexpr int load_block_stride = load_step_stride * unroll;

    // Store constants
    constexpr int store_step_stride = h_per_chunk * threads;
    constexpr int store_block_stride = store_step_stride * unroll;
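
    // Worked example (illustrative, assuming numBits = 8, threads = 256,
    // unroll = 4, granularity = 16): each thread loads load_granularity =
    // 16 * 8 / 16 = 8 bytes per step, one step covers load_step_stride =
    // 8 * 256 = 2048 bytes, and one block covers load_block_stride =
    // 2048 * 4 = 8192 bytes, which expand into store_block_stride =
    // 8 * 256 * 4 = 8192 __half outputs.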

    // Load offsets
    const int load_block_offset = tb.group_index().x * load_block_stride;
    // Note: we can use `load_granularity` directly since the dtype is `int8_t`.
    const int load_thread_offset = tb.thread_index().x * load_granularity;
    const int8_t* load_base = data + load_block_offset + load_thread_offset;

    // Store offsets
    const int store_block_offset = tb.group_index().x * store_block_stride;
    const int store_thread_offset = tb.thread_index().x * h_per_chunk;
    const int elem_id_base = store_block_offset + store_thread_offset;

    int8_t local_load_buffer[load_granularity * unroll];
    __half local_dequant_buffer[h_per_chunk * unroll];

    /*
    Note: Splitting this loop in half gave a 3-5% performance increase for reasons that
    aren't totally clear to me, so this is a deliberately weird code structure.
    */
    // First pass: issue the global loads for all unroll steps.
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        const int elem_id_iter = elem_id_base + i * store_step_stride;

        if (elem_id_iter < total_elems) {
            mem_access::load_global<load_granularity>(local_load_buffer + i * load_granularity,
                                                      load_base + i * load_step_stride);
        }
    }

    // Second pass: dequantize each loaded chunk and store it to global memory.
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        const int elem_id_iter = elem_id_base + i * store_step_stride;

        if (elem_id_iter < total_elems) {
            // TODO(cmikeh2): Can we amortize this division? Perform it once on the first
            // iteration and use indexing math to do division-free interpolation of the
            // successive groups?
            const int group_index = elem_id_iter / elems_per_group;
            Params<qType, numBits> q_params(global_params, group_index);

            chunk<numBits, qType>(local_dequant_buffer + i * h_per_chunk,
                                  local_load_buffer + i * load_granularity,
                                  q_params);
            mem_access::store_global<granularity>(global_output + elem_id_iter,
                                                  local_dequant_buffer + i * h_per_chunk);
        }
    }
}

template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(__half* global_output,
                           const int8_t* data,
                           const float* global_params,
                           const int elems_per_group,
                           const int total_elems)
{
    if constexpr (numBits == 4 || numBits == 8) {
        _to_global<numBits, qType, unroll, threads>(
            global_output, data, global_params, elems_per_group, total_elems);
    } else if constexpr (numBits == 3) {
        // TODO(cmikeh2): Need this implementation
        assert(false);
    } else {
        assert(false);
    }
}
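
// Example instantiations (illustrative): to_global<8, Type::Symmetric, unroll, threads>
// and to_global<4, Type::Asymmetric, unroll, threads> take the supported path above;
// numBits == 3 currently falls through to assert(false).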

}  // namespace dequantize