dequantization_utils.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "quantization.h"
#include "quantization_utils.h"

namespace cg = cooperative_groups;

namespace dequantize {
using Type = quantize::Type;

template <Type qType, int numBits>
using Params = quantize::Params<qType, numBits>;

constexpr int granularity = quantize::granularity;
using PackedInt4 = quantize::PackedInt4;

constexpr int h_per_chunk = granularity / sizeof(__half);
constexpr int h2_per_chunk = granularity / sizeof(__half2);
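// With the typical 16-byte access granularity these evaluate to h_per_chunk == 8
// and h2_per_chunk == 4; the actual value comes from quantize::granularity, so
// treat these numbers as illustrative.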
/*
Device function that reads quantized data from global memory, dequantizes
it, and stores it to global memory.

Template Arguments :
    numBits - Number of bits in quantized element.          int: 4, 8
    qType   - Type of quantization to perform.              Type::Symmetric or Type::Asymmetric
    unroll  - Number of load steps to internally unroll.    int
    threads - Number of threads to perform the dequant.     int

Function Arguments :
    global_output - Pointer for dequantized output in global memory
    data - Quantized data in global memory
    global_params - Quantization parameters in global memory
    elems_per_group - Number of elements in each quantization group
    total_elems - Tensor size (note: does not need to be a multiple of elems_per_group)
*/
template <typename T, int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(T* global_output,
                           const int8_t* data,
                           const float* global_params,
                           const int elems_per_group,
                           const int total_elems);
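/*
Usage sketch (hypothetical, not part of the original header): a kernel that
dequantizes an 8-bit, symmetrically quantized buffer. The template values here
are illustrative and must match the launch geometry.

__global__ void dequant_kernel_example(__half* out,
                                       const int8_t* quantized_data,
                                       const float* quant_params,
                                       int elems_per_group,
                                       int total_elems)
{
    dequantize::to_global<__half, 8, dequantize::Type::Symmetric, 4, 256>(
        out, quantized_data, quant_params, elems_per_group, total_elems);
}
*/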
/*
Device function that dequantizes one chunk (16 bytes of __half-type output)
of quantized input data.

Template Arguments :
    numBits - Number of bits in quantized element.  int: 8 or 4
    qType   - Type of quantization performed.       Type::Symmetric or Type::Asymmetric

Function Arguments :
    local_output - Local array to store dequantized data.  __half* or __half2*
    data - Pointer to quantized input data.                 int8_t*
    q_params - Parameters for quantization.                 Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params);

template <typename T, int numBits, Type qType>
DS_D_INLINE void chunk(T* local_output, const int8_t* data, Params<qType, numBits> q_params);
/**************** Implementations ******************/

template <typename T, int numBits, Type qType>
DS_D_INLINE void chunk(T* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
    constexpr int32_t num_elems_packed = 8 / numBits;
    constexpr int32_t iters = h_per_chunk / num_elems_packed;

#pragma unroll
    for (int i = 0; i < iters; i++) {
        if constexpr (num_elems_packed == 1) {
            local_output[i] = q_params.template dequantize<T>(data[i]);
        } else {
            auto accessible_data = *(PackedInt4*)(&data[i]);
            local_output[2 * i] = q_params.template dequantize<T>(accessible_data.low);
            local_output[2 * i + 1] = q_params.template dequantize<T>(accessible_data.high);
        }
    }
}
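/*
Worked example (illustrative): for numBits == 4, each input byte packs two
4-bit values (num_elems_packed == 2), so with the typical h_per_chunk of 8 the
loop runs iters == 4 times, reading one byte per iteration through PackedInt4
and emitting its low and high nibbles as two adjacent dequantized outputs.
For numBits == 8 the loop degenerates to one dequantized element per byte.
*/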
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
    __half* local_output_cast = reinterpret_cast<__half*>(local_output);
    chunk<__half, numBits>(local_output_cast, data, q_params);
}
template <typename T, int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void _to_global(T* global_output,
                            const int8_t* data,
                            const float* global_params,
                            const int elems_per_group,
                            const int total_elems)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // Load constants
    // TODO(cmikeh2): Refactor into functions?
    constexpr int load_granularity = (granularity / (sizeof(T))) / (numBits == 8 ? 1 : 2);
    constexpr int load_step_stride = load_granularity * threads;
    constexpr int load_block_stride = load_step_stride * unroll;

    // Store constants
    constexpr int T_per_chunk = granularity / sizeof(T);
    constexpr int store_step_stride = T_per_chunk * threads;
    constexpr int store_block_stride = store_step_stride * unroll;
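    // Worked example (assuming the usual 16-byte granularity): with T = __half,
    // numBits = 4, threads = 256, and unroll = 4, each thread loads
    // load_granularity = (16 / 2) / 2 = 4 bytes of packed data per step and
    // produces T_per_chunk = 8 halfs, so one block covers
    // store_block_stride = 8 * 256 * 4 = 8192 output elements.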
    // Load offsets
    const int load_block_offset = tb.group_index().x * load_block_stride;
    // Note: we can use `load_granularity` since the dtype is `int8_t`.
    const int load_thread_offset = tb.thread_index().x * load_granularity;
    const int8_t* load_base = data + load_block_offset + load_thread_offset;

    // Store offsets
    const int store_block_offset = tb.group_index().x * store_block_stride;
    const int store_thread_offset = tb.thread_index().x * T_per_chunk;
    const int elem_id_base = store_block_offset + store_thread_offset;

    int8_t local_load_buffer[load_granularity * unroll];
    T local_dequant_buffer[T_per_chunk * unroll];
    /*
    Note: Splitting this loop in half gave about a 3-5% performance increase for reasons
    that aren't totally clear to me, so this is a deliberately weird code structure.
    */
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        const int elem_id_iter = elem_id_base + i * store_step_stride;
        if (elem_id_iter < total_elems) {
            mem_access::load_global<load_granularity>(local_load_buffer + i * load_granularity,
                                                      load_base + i * load_step_stride);
        }
    }
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        const int elem_id_iter = elem_id_base + i * store_step_stride;
        if (elem_id_iter < total_elems) {
            // TODO(cmikeh2): Can we amortize this division? Perform once on the first
            // iteration and use indexing math to do division-free interpolation of the
            // successive groups?
            const int group_index = elem_id_iter / elems_per_group;
            Params<qType, numBits> q_params(global_params, group_index);
            chunk<T, numBits, qType>(local_dequant_buffer + i * T_per_chunk,
                                     local_load_buffer + i * load_granularity,
                                     q_params);
            mem_access::store_global<granularity>(global_output + elem_id_iter,
                                                  local_dequant_buffer + i * T_per_chunk);
        }
    }
}
template <typename T, int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(T* global_output,
                           const int8_t* data,
                           const float* global_params,
                           const int elems_per_group,
                           const int total_elems)
{
    if constexpr (numBits == 4 || numBits == 8) {
        _to_global<T, numBits, qType, unroll, threads>(
            global_output, data, global_params, elems_per_group, total_elems);
    } else if constexpr (numBits == 3) {
        // TODO(cmikeh2): Need this implementation
        assert(false);
    } else {
        assert(false);
    }
}

}  // namespace dequantize
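/*
Host-side launch sketch (hypothetical) for the example kernel sketched above:
each block covers threads * unroll * h_per_chunk output elements, so the grid
is sized to span total_elems. The out/quantized_data/quant_params buffers are
assumed to already live in device memory.

const int threads = 256;
const int unroll = 4;
const int elems_per_block = threads * unroll * dequantize::h_per_chunk;
const int blocks = (total_elems + elems_per_block - 1) / elems_per_block;
dequant_kernel_example<<<blocks, threads>>>(
    out, quantized_data, quant_params, elems_per_group, total_elems);
*/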