quantize.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <stdexcept>
#include "context.h"
#include "memory_access_utils.h"
#include "quantize.h"
#include "reduction_utils.h"

#include <cuda.h>
#include <stdint.h>

#include <cuda_fp16.h>
#include <curand_kernel.h>

#include <cuda_bf16.h>
#include <cuda_runtime_api.h>

using ROp = reduce::ROpType;

namespace quantization {

constexpr int access_granularity = 16;
constexpr int quanitzed_access_granularity = 4;
constexpr int quanitzed_access_granularity_6bits = 2;

constexpr int threads = 256;
constexpr int warps = threads / 32;

}  // namespace quantization
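
// Adds a rounding offset to the source-format mantissa ahead of truncation to
// q_mantisa_bits: a stochastic offset drawn from the per-thread Philox state, or half a
// unit in the last kept place for nearest rounding. If the addition carries out of the
// source mantissa field, the destination exponent is incremented.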
template <int _mantisa_bits, int q_mantisa_bits, int stochastic_rounding>
__device__ void round(uint32_t& mantisa, uint32_t& dst_exponent, curandStatePhilox4_32_10_t* state)
{
    constexpr uint32_t mantisa_mask = (1 << (_mantisa_bits - q_mantisa_bits)) - 1;
    uint32_t offset = stochastic_rounding ? (curand_poisson(state, 10) & mantisa_mask)
                                          : 1 << (_mantisa_bits - q_mantisa_bits - 1);
    mantisa += offset;
    dst_exponent += (((mantisa & ~mantisa_mask) == (1 << _mantisa_bits)) ? 1 : 0);
}
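
// Clamps the exponent to the range representable by the quantized format: on overflow the
// exponent is pinned to max_exponent and the mantissa saturated, on underflow the exponent
// is pinned to min_exponent and the mantissa flushed to zero.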
template <int _mantisa_bits, int _exponent_bits, int q_mantisa_bits, int q_exponent_bits>
__device__ void clip(uint32_t& exponent, uint32_t& mantisa)
{
    constexpr uint32_t max_exponent = (1 << (q_exponent_bits - 1)) + (1 << (_exponent_bits - 1));
    constexpr uint32_t min_exponent =
        (1 << (_exponent_bits - 1)) - ((1 << (q_exponent_bits - 1)) - 1);
    if (exponent > max_exponent) {
        exponent = max_exponent;
        mantisa = (((uint32_t)-1) >> (32 - q_mantisa_bits)) << 1;  //.11 .. 10
    }
    if (exponent < min_exponent) {
        exponent = min_exponent;
        mantisa = 0;
    }
}
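
// Quantization kernel: each warp handles one group of `group_size` elements. It computes
// the group's absolute maximum, scales every value by q_range / max, splits the scaled value
// into sign/exponent/mantissa, rounds and clips to the target bit-widths, packs the results
// into bytes, and stores the per-group scale (max / q_range) as a trailing 4-byte float.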
template <typename T,
          int unroll,
          int _mantisa_bits,
          int _exponent_bits,
          int total_q_bits = 8,
          int q_mantisa_bits = 3,
          int stochastic_rounding = 0>
__global__ void apply_quantization(T* val,
                                   uint8_t* q_val,
                                   int group_size,
                                   std::pair<uint64_t, uint64_t> seed,
                                   float q_range)
{
    int tidx = threadIdx.x;
    int wid = tidx >> 5;
    int lane = tidx & 0x1f;
    int gid = blockIdx.x * quantization::warps + wid;

    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
    constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1;
    constexpr uint32_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
    constexpr uint32_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);

    // CG helpers
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
    constexpr uint32_t load_stride = vector_size * hw_warp_size;
    constexpr uint32_t store_stride = (total_q_bits * vector_size / 8) * hw_warp_size;
    const uint32_t thread_offset = lane * vector_size;
    const uint32_t store_thread_offset = lane * (total_q_bits * vector_size / 8);
    const uint32_t base_load_offset = gid * group_size + thread_offset;
    const uint32_t base_store_offset =
        gid * ((group_size * total_q_bits / 8) + 4) +
        store_thread_offset;  // 4-byte for saving the scale per group

    const T* load_base_ptr = val + base_load_offset;
    T tmp_buf[unroll * vector_size];
    T cur_max;
    reduce::init<ROp::Max>(&cur_max);

    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

#pragma unroll
    for (int i = 0; i < unroll; i++) {
        if (i * load_stride + thread_offset < group_size) {
            mem_access::load_global<quantization::access_granularity>(
                &tmp_buf[vector_size * i], load_base_ptr + i * load_stride);
            for (int j = 0; j < vector_size; j++)
                cur_max = reduce::element<ROp::Max>(cur_max, __habs(tmp_buf[i * vector_size + j]));
        }
    }
    reduce::_block<T, 1, ROp::Max>(tb, warp, &cur_max);

    int mantisa_mask = ((1 << q_mantisa_bits) - 1);
    mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);

    uint8_t* store_base_ptr = q_val + base_store_offset;
    float scale = (float)q_range / conversion::to<float>(cur_max);
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        if (i * load_stride + thread_offset < group_size) {
            uint64_t q_buf = 0;
            uint64_t q_buf1 = 0;
#pragma unroll
            for (int j = 0; j < vector_size; j++) {
                float val_f = conversion::to<float>(tmp_buf[i * vector_size + j]) * scale;
                uint32_t* data = reinterpret_cast<uint32_t*>(&val_f);
                uint32_t sign = (data[0] & _sign_mask) >> (_mantisa_bits + _exponent_bits);
                uint32_t cur_exponent = (data[0] & _exponent_mask) >> _mantisa_bits;
                uint32_t dst_mantisa = (data[0] & _mantisa_mask);
                uint32_t dst_exponent = cur_exponent;

                round<_mantisa_bits, q_mantisa_bits, stochastic_rounding>(
                    dst_mantisa, dst_exponent, &state);
                if (cur_exponent != 0)
                    clip<_mantisa_bits, _exponent_bits, q_mantisa_bits, q_exponent_bits>(
                        dst_exponent, dst_mantisa);

                dst_mantisa = (dst_mantisa & mantisa_mask) >> (_mantisa_bits - q_mantisa_bits);

                if (dst_exponent != (1 << q_exponent_bits) - 1)
                    dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
                                   (1 << (q_exponent_bits - 1)) - 1;

                if (total_q_bits == 8 || total_q_bits == 4 || total_q_bits == 6)
                    q_buf = q_buf |
                            ((uint64_t)((uint8_t)(sign << (q_exponent_bits + q_mantisa_bits) |
                                                  (dst_exponent << q_mantisa_bits) | dst_mantisa))
                             << j * total_q_bits);
                else if (total_q_bits == 12) {
                    if (j < 5)
                        q_buf =
                            q_buf |
                            ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) |
                                                   (dst_exponent << q_mantisa_bits) | dst_mantisa))
                             << j * total_q_bits);
                    else
                        q_buf1 =
                            q_buf1 |
                            ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) |
                                                   (dst_exponent << q_mantisa_bits) | dst_mantisa))
                             << (j - 5) * total_q_bits);
                }
            }
            if (total_q_bits == 12) {
                uint64_t last_nibble_mask = 0xf;
                last_nibble_mask = q_buf1 & last_nibble_mask;
                q_buf = (last_nibble_mask << 60) | q_buf;
                q_buf1 >>= 4;
            }

            uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf);
            uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf1);
            if (total_q_bits == 6) {
                mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                    store_base_ptr + i * store_stride, int8_data);
                mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                    store_base_ptr + i * store_stride +
                        quantization::quanitzed_access_granularity_6bits,
                    int8_data + quantization::quanitzed_access_granularity_6bits);
                mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                    store_base_ptr + i * store_stride +
                        quantization::quanitzed_access_granularity_6bits * 2,
                    int8_data + 2 * quantization::quanitzed_access_granularity_6bits);
            } else {
                mem_access::store_global<quantization::quanitzed_access_granularity>(
                    store_base_ptr + i * store_stride, int8_data);

                if (total_q_bits > 4) {
                    mem_access::store_global<quantization::quanitzed_access_granularity>(
                        store_base_ptr + i * store_stride +
                            quantization::quanitzed_access_granularity,
                        int8_data + quantization::quanitzed_access_granularity);
                    if (total_q_bits == 12) {
                        mem_access::store_global<quantization::quanitzed_access_granularity>(
                            store_base_ptr + i * store_stride +
                                quantization::quanitzed_access_granularity * 2,
                            int8_data1);
                    }
                }
            }
        }
    }
    if (lane == 0) {
        float q_scale = conversion::to<float>(cur_max) / (float)q_range;
        uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&q_scale);
        uint32_t scale_offset =
            gid * ((group_size * total_q_bits / 8) + 4) + (group_size * total_q_bits / 8);
        if (total_q_bits != 6)
            mem_access::store_global<quantization::quanitzed_access_granularity>(
                q_val + scale_offset, scale_as_int8);
        else {
            mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                q_val + scale_offset, scale_as_int8);
            mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                q_val + scale_offset + quantization::quanitzed_access_granularity_6bits,
                scale_as_int8 + quantization::quanitzed_access_granularity_6bits);
        }
    }
}
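
// Dequantization kernel: each thread loads a vector of packed quantized values and the
// 4-byte scale stored at the end of its group, expands sign/exponent/mantissa back into
// the 16-bit target format T, rescales, and writes the result to `q_val`.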
template <typename T,
          int q_mantisa_bits,
          int total_q_bits = 16,
          int _mantisa_bits = 3,
          int _exponent_bits = 4>
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
{
    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
    int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;

    constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
    constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
    constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
    constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);

    const uint32_t g_index = (tidx / group_size);
    const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
    const uint8_t* load_base_ptr =
        val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8;

    int mantisa_mask = ((1 << q_mantisa_bits) - 1);
    mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);

    T* store_base_ptr = q_val + tidx;
    float scale;

    uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
    if (quantized_bits == 6) {
        mem_access::load_global<quantization::quanitzed_access_granularity>(
            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
        mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
            scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
            val + g_index * (group_size_bytes + 4) + group_size_bytes +
                quantization::quanitzed_access_granularity_6bits);
    } else
        mem_access::load_global<quantization::quanitzed_access_granularity>(
            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);

    if (tidx < total_num_elements) {
        uint64_t q_buf_in;
        uint64_t q_buf_in1;
        uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
        uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
        if (quantized_bits == 6) {
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data, load_base_ptr);
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data + quantization::quanitzed_access_granularity_6bits,
                load_base_ptr + quantization::quanitzed_access_granularity_6bits);
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data + quantization::quanitzed_access_granularity_6bits * 2,
                load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
        } else {
            mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
                                                                                load_base_ptr);
            if (quantized_bits > 4) {
                mem_access::load_global<quantization::quanitzed_access_granularity>(
                    int8_data + quantization::quanitzed_access_granularity,
                    load_base_ptr + quantization::quanitzed_access_granularity);
                if (quantized_bits == 12) {
                    mem_access::load_global<quantization::quanitzed_access_granularity>(
                        int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
                }
            }
        }
        T store_buf[vector_size];
        uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
#pragma unroll
        for (int j = 0; j < vector_size; j++) {
            uint16_t new_data;
            if (j < 5 || quantized_bits != 12) {
                new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
            } else {
                if (j == 5) {
                    new_data = (uint16_t)(q_buf_in1);
                    new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
                } else
                    new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
            }

            uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
            uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
            uint16_t dst_mantisa = (new_data & _mantisa_mask);

            if (dst_exponent != (1 << q_exponent_bits) - 1)
                dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
                               (1 << (q_exponent_bits - 1)) - 1;

            q_buf[j] =
                ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
                 (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
            float up_cast = conversion::to<float>(store_buf[j]);
            store_buf[j] = conversion::to<T>(up_cast * scale);
        }
        mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
    }
}
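
// Host-side launcher for quantization: one warp per group, with the per-group unroll factor
// computed at runtime and dispatched to the matching apply_quantization instantiation (1-6).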
#define LAUNCH_FOR_QUANTIZATION_UNROLL(COUNT)                                     \
    case COUNT:                                                                   \
        apply_quantization<T,                                                     \
                           COUNT,                                                 \
                           mantisa,                                               \
                           exponent,                                              \
                           CONST_Q_BITS,                                          \
                           CONST_Q_MANTISA_BITS,                                  \
                           CONST_STOCHASTIC_ROUNDING>                             \
            <<<grid, block, 0, stream>>>(val, q_val, group_size, seed, q_range);  \
        break;

template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
                         uint8_t* q_val,
                         int num_groups,
                         int group_size,
                         cudaStream_t stream,
                         float q_range,
                         int q_bits,
                         int q_mantisa_bits,
                         int stochastic_rounding)
{
    const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
    const dim3 block(quantization::threads);

    std::pair<uint64_t, uint64_t> seed = FPContext::Instance().IncrementOffset(16);

    constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);

    const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
    QUANT_SWITCH((q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, [&] {
        switch (copy_unroll) {
            LAUNCH_FOR_QUANTIZATION_UNROLL(1)
            LAUNCH_FOR_QUANTIZATION_UNROLL(2)
            LAUNCH_FOR_QUANTIZATION_UNROLL(3)
            LAUNCH_FOR_QUANTIZATION_UNROLL(4)
            LAUNCH_FOR_QUANTIZATION_UNROLL(5)
            LAUNCH_FOR_QUANTIZATION_UNROLL(6)
        }
    });
}

#define INSTANTIATE_LAUNCH_QUANTIZATION(T, mantisa, exponent) \
    template void launch_quantization<T, mantisa, exponent>(  \
        T*, uint8_t*, int, int, cudaStream_t, float q_range, int, int, int);
// fp8(E4M3), nearest-rounding
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
#endif
INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);
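
// Host-side launcher for full dequantization: spreads all (num_groups * group_size)
// elements across blocks of quantization::threads threads, each thread handling one
// vector of quantization::access_granularity bytes of output.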
template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
                           T* q_val,
                           int num_groups,
                           int group_size,
                           int q_mantisa_bits,
                           int q_exponent_bits,
                           cudaStream_t stream)
{
    int blocks = ((num_groups * group_size) - 1) /
                     (quantization::threads * (quantization::access_granularity / sizeof(T))) +
                 1;
    const dim3 grid(blocks);
    const dim3 block(quantization::threads);

    DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
        apply_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
            <<<grid, block, 0, stream>>>(val, q_val, group_size, (num_groups * group_size));
    });
}

#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
    template void launch_dequantization<T, mantisa>(uint8_t*, T*, int, int, int, int, cudaStream_t);
// fp8(E4M3)
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
#endif
INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10);
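
// Variant of apply_dequantization that only dequantizes the rows selected by `indexes`:
// blockIdx.x picks the row index and blockIdx.y tiles the elements within that row.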
template <typename T,
          int q_mantisa_bits,
          int total_q_bits = 16,
          int _mantisa_bits = 3,
          int _exponent_bits = 4>
__global__ void apply_selective_dequantization(uint8_t* val,
                                               T* q_val,
                                               int32_t* indexes,
                                               int group_size,
                                               int total_num_elements)
{
    int index = indexes[blockIdx.x];
    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
    int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
    int input_index = index * total_num_elements + tidx;
    constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
    constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
    constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
    constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
    const uint32_t g_index = (input_index / group_size);
    const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
    const uint8_t* load_base_ptr =
        val + g_index * (group_size_bytes + 4) + (input_index % group_size) * quantized_bits / 8;

    int mantisa_mask = ((1 << q_mantisa_bits) - 1);
    mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);

    T* store_base_ptr = q_val + tidx + blockIdx.x * total_num_elements;
    float scale;

    uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
    if (quantized_bits == 6) {
        mem_access::load_global<quantization::quanitzed_access_granularity>(
            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
        mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
            scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
            val + g_index * (group_size_bytes + 4) + group_size_bytes +
                quantization::quanitzed_access_granularity_6bits);
    } else
        mem_access::load_global<quantization::quanitzed_access_granularity>(
            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);

    if (tidx < total_num_elements) {
        uint64_t q_buf_in;
        uint64_t q_buf_in1;
        uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
        uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
        if (quantized_bits == 6) {
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data, load_base_ptr);
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data + quantization::quanitzed_access_granularity_6bits,
                load_base_ptr + quantization::quanitzed_access_granularity_6bits);
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data + quantization::quanitzed_access_granularity_6bits * 2,
                load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
        } else {
            mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
                                                                                load_base_ptr);
            if (quantized_bits > 4) {
                mem_access::load_global<quantization::quanitzed_access_granularity>(
                    int8_data + quantization::quanitzed_access_granularity,
                    load_base_ptr + quantization::quanitzed_access_granularity);
                if (quantized_bits == 12) {
                    mem_access::load_global<quantization::quanitzed_access_granularity>(
                        int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
                }
            }
        }
        T store_buf[vector_size];
        uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
#pragma unroll
        for (int j = 0; j < vector_size; j++) {
            uint16_t new_data;
            if (j < 5 || quantized_bits != 12) {
                new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
            } else {
                if (j == 5) {
                    new_data = (uint16_t)(q_buf_in1);
                    new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
                } else
                    new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
            }

            uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
            uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
            uint16_t dst_mantisa = (new_data & _mantisa_mask);

            if (dst_exponent != (1 << q_exponent_bits) - 1)
                dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
                               (1 << (q_exponent_bits - 1)) - 1;

            q_buf[j] =
                ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
                 (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
            float up_cast = conversion::to<float>(store_buf[j]);
            store_buf[j] = conversion::to<T>(up_cast * scale);
        }
        mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
    }
}
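
// Host-side launcher for selective dequantization: a 2D grid with one x-slice per index and
// enough y-blocks to cover the (num_groups / num_indexes) * group_size elements per index.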
template <typename T, int mantisa>
void launch_selective_dequantization(uint8_t* val,
                                     T* q_val,
                                     int32_t* indexes,
                                     int num_groups,
                                     int group_size,
                                     int num_indexes,
                                     int q_mantisa_bits,
                                     int q_exponent_bits,
                                     cudaStream_t stream)
{
    int total_elements_per_index = (num_groups / num_indexes) * group_size;
    int blocks = (total_elements_per_index - 1) /
                     (quantization::threads * (quantization::access_granularity / sizeof(T))) +
                 1;
    const dim3 grid(num_indexes, blocks);
    const dim3 block(quantization::threads);
    DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
        apply_selective_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
            <<<grid, block, 0, stream>>>(val, q_val, indexes, group_size, total_elements_per_index);
    });
}

#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \
    template void launch_selective_dequantization<T, mantisa>(  \
        uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t);
// fp8(E4M3)
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7);
#endif
INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__half, 10);