quantization_utils.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cassert>
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "reduction_utils.h"

using rop = reduce::ROpType;

namespace quantize {

constexpr int granularity = 16;
constexpr int h_per_load = granularity / sizeof(__half);
constexpr int h2_per_load = granularity / sizeof(__half2);
constexpr int max_threads = 1024;
/*
Class to hold the quantization parameters for a given tensor.
Holds the implementation of the quantization operation.
*/
template <Type qType, int numBits>
class Params {
public:
    /*
    Quantization implementation, supports
    1) 4 Bit
    2) 8 Bit
    3) Symmetric
    4) Asymmetric

    Function Arguments :
        val : The __half value to quantize.
    */
    DS_D_INLINE int8_t quantize(__half val);

    template <typename T>
    DS_D_INLINE T dequantize(int8_t val);

    DS_D_INLINE void store(float* params, int group_index);

    // Initialize from memory
    DS_D_INLINE Params(const float* params, int group_index);
};
template <int numBits>
class Params<Type::Symmetric, numBits> {
public:
    float scale;

    DS_D_INLINE Params(float max)
    {
        if (max == 0) {
            scale = 1.0;
        } else {
            scale = (1 << numBits) / (2 * max);
        }
    }

    DS_D_INLINE int8_t quantize(__half val)
    {
        constexpr int32_t q_min = -(1 << (numBits - 1));
        constexpr int32_t q_max = (1 << (numBits - 1)) - 1;

        float val_f = conversion::to<float>(val) * scale;
        int32_t data_i32 = conversion::to<int32_t>(val_f);
        data_i32 = min(max(data_i32, q_min), q_max);
        return (int8_t)data_i32;
    }

    template <typename T>
    DS_D_INLINE T dequantize(int8_t val)
    {
        const float val_deq_f = conversion::to<float>(val) * scale;
        return conversion::to<T>(val_deq_f);
    }

    DS_D_INLINE void store(float* params, int group_index)
    {
        const float store_scale = 1 / scale;
        mem_access::store_global<sizeof(float)>(params + group_index, &store_scale);
    }

    DS_D_INLINE Params(const float* params, int group_index)
    {
        mem_access::load_global<sizeof(float)>(&scale, params + group_index);
    }
};
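
/*
Worked example (illustrative only, not part of the upstream API): with numBits = 8 and
a group absmax of 2.0, the constructor yields scale = 256 / (2 * 2.0) = 64, so a value
of 1.5 quantizes to 1.5 * 64 = 96, well inside [-128, 127]. store() writes the
reciprocal (1 / 64 = 0.015625), which is what the memory-initialized constructor loads
back into `scale`; dequantize() therefore recovers the original magnitude with a single
multiply:

    quantize::Params<quantize::Type::Symmetric, 8> p8(2.0f);  // scale == 64
    int8_t q8 = p8.quantize(__float2half(1.5f));              // q8 == 96
*/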
template <int numBits>
class Params<Type::Asymmetric, numBits> {
public:
    float scale;
    float offset;

    DS_D_INLINE Params(float max, float min)
    {
        if (max == min) {
            scale = 1.0;
        } else {
            scale = (1 << numBits) / (max - min);
        }
        offset = (max + min) / 2;
    }

    DS_D_INLINE int8_t quantize(__half val)
    {
        constexpr int32_t q_min = -(1 << (numBits - 1));
        constexpr int32_t q_max = (1 << (numBits - 1)) - 1;

        float val_f = (conversion::to<float>(val) - offset) * scale;
        int32_t data_i32 = conversion::to<int32_t>(val_f);
        data_i32 = min(max(data_i32, q_min), q_max);
        return (int8_t)data_i32;
    }

    template <typename T>
    DS_D_INLINE T dequantize(int8_t val)
    {
        const float val_deq_f = conversion::to<float>(val) * scale + offset;
        return conversion::to<T>(val_deq_f);
    }

    DS_D_INLINE void store(float* params, int group_index)
    {
        // Codegen should turn this into stg.64
        const float store_scale = 1 / scale;
        mem_access::store_global<sizeof(float)>(params + 2 * group_index, &store_scale);
        mem_access::store_global<sizeof(float)>(params + 2 * group_index + 1, &offset);
    }

    DS_D_INLINE Params(const float* params, int group_index)
    {
        // Codegen should turn this into ldg.64
        mem_access::load_global<sizeof(float)>(&scale, params + 2 * group_index);
        mem_access::load_global<sizeof(float)>(&offset, params + 2 * group_index + 1);
    }
};
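
/*
Worked example (illustrative only, not taken from the upstream code): with numBits = 8,
max = 3.0 and min = -1.0, the constructor yields scale = 256 / 4.0 = 64 and
offset = 1.0. The range midpoint (1.0) quantizes to 0, min quantizes to -128, and max
to (3.0 - 1.0) * 64 = 128, which is clamped to 127. store() writes the pair
{1 / 64, offset}, so the memory-initialized constructor dequantizes with
q * 0.015625 + 1.0 (e.g. 127 -> roughly 2.984).
*/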
/*
Group stats tracks the necessary statistics about the quantized group
to abstract the particulars for the main loop.
*/
template <Type qType>
class GroupStats {
public:
    DS_D_INLINE void update(__half2 val);

    DS_D_INLINE void reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp);
};

template <>
class GroupStats<Type::Symmetric> {
public:
    // Symmetric quantization only tracks the maximum absolute value
    __half2 cur_max;
    float max;

    /*
    Technically, this would give bad results if there
    are 0 values to process since the reduction would
    give -inf instead of 0. We do not consider this
    to be a reasonable edge case.
    */
    DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, __half2>(); }

    /*
    Updates the running absmax used to calculate params.

    Function Arguments :
        val : The __half2 value to update the running absmax with.
    */
    DS_D_INLINE void update(__half2 val)
    {
        cur_max = reduce::element<rop::Max>(cur_max, __habs2(val));
    }

    /*
    Function to return calculated quantization params.

    Template Arguments :
        numBits - Number of bits in quantized element. int : 8 or 4

    Function Arguments :
        tb - Threadblock object. cg::thread_block
        warp - Warp object. cg::thread_block_tile<hw_warp_size>
    */
    template <int numBits, int threads_per_group>
    DS_D_INLINE Params<Type::Symmetric, numBits> get_params(
        cg::thread_block& tb,
        cg::thread_block_tile<hw_warp_size>& warp)
    {
        const float2 partial_max = conversion::to<float2>(cur_max);
        float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);

        reduce::partitioned_block<rop::Max, threads_per_group>(tb, warp, max);
        Params<Type::Symmetric, numBits> params(max);

        return params;
    }
};
template <>
class GroupStats<Type::Asymmetric> {
public:
    __half2 cur_max;
    __half2 cur_min;

    /*
    Initialize cur_max to -inf, cur_min to inf since
    we are doing a true range analysis.
    */
    DS_D_INLINE GroupStats()
    {
        cur_max = reduce::init<rop::Max, __half2>();
        cur_min = reduce::init<rop::Min, __half2>();
    }

    /*
    Updates the running min and max used to calculate params.

    Function Arguments :
        val : The __half2 value to update the running min and max with.
    */
    DS_D_INLINE void update(__half2 val)
    {
        cur_max = reduce::element<rop::Max>(cur_max, val);
        cur_min = reduce::element<rop::Min>(cur_min, val);
    }

    /*
    Function to return calculated quantization params.

    Template Arguments :
        numBits - Number of bits in quantized element. int : 8 or 4

    Function Arguments :
        tb - Threadblock object. cg::thread_block
        warp - Warp object. cg::thread_block_tile<hw_warp_size>
    */
    template <int numBits, int threads_per_group>
    DS_D_INLINE Params<Type::Asymmetric, numBits> get_params(
        cg::thread_block& tb,
        cg::thread_block_tile<hw_warp_size>& warp)
    {
        const float2 partial_max = conversion::to<float2>(cur_max);
        float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);

        const float2 partial_min = conversion::to<float2>(cur_min);
        float min = reduce::element<rop::Min>(partial_min.x, partial_min.y);

        reduce::partitioned_block<rop::Max, rop::Min, threads_per_group>(tb, warp, max, min);
        Params<Type::Asymmetric, numBits> params(max, min);

        return params;
    }
};
/*
Device function that quantizes 16 bytes of __half type input data.

Template Arguments :
    numBits - Number of bits in quantized element. int : 8 or 4
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric

Function Arguments :
    local_output - Pointer to local memory to store quantized data. int8_t*
    data - Pointer to input data. __half*
    q_params - Parameters for quantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half* data, Params<qType, numBits> q_params);

/*
Device function that quantizes 16 bytes of __half2 type input data.

Template Arguments :
    numBits - Number of bits in quantized element. int : 8 or 4
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric

Function Arguments :
    local_output - Pointer to local memory to store quantized data. int8_t*
    data - Pointer to input data. __half2*
    q_params - Parameters for quantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half2* data, Params<qType, numBits> q_params);

/*
Helper function to do serial reduction on register-file arrays.

Template Arguments :
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
    numChunks - Number of 16-byte chunks held in the register buffer. int : 8 or 4

Function Arguments :
    local_buffer - Pointer to memory with input half2 data to be quantized.
*/
template <Type qType, int numChunks>
DS_D_INLINE GroupStats<qType> _local_serial_reduce(__half2* local_buffer);
/*
The main loop of the kernel that quantizes an array of __half2 input data held in
local memory, when the quantization parameters are pre-computed.

Template Arguments :
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
    numBits - Number of bits in quantized element. int : 8 or 4
    numChunks - Number of 16-byte chunks of input data. int : 8 or 4

Function Arguments :
    local_buffer - Pointer to memory with input half2 data to be quantized.
    scales - Pointer to output scales.
    offsets - Pointer to output offsets.
    output_data - Pointer to output data.
    elems_per_group - Number of elements to quantize in a group.
    q_params - Quantization parameters.
*/
template <int numBits, Type qType, int numChunks, int threads_per_group, int max_threads>
DS_D_INLINE void local_array(cg::thread_block& tb,
                             cg::thread_block_tile<hw_warp_size>& warp,
                             __half2* local_buffer,
                             float* __restrict__ scales,
                             float* __restrict__ offsets,
                             int8_t* __restrict__ output_data,
                             const int& elems_per_group,
                             const int& groups,
                             Params<qType, numBits> q_params);

/*
The main loop of the kernel that quantizes an array of __half2 input data held in
local memory. This function computes the quantization parameters for each group.

Template Arguments :
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
    numBits - Number of bits in quantized element. int : 8 or 4
    numChunks - Number of 16-byte chunks of input data. int : 8 or 4

Function Arguments :
    local_buffer - Pointer to memory with input half2 data to be quantized.
    scales - Pointer to output scales.
    offsets - Pointer to output offsets.
    output_data - Pointer to output data.
    elems_per_group - Number of elements to quantize in a group.
*/
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
__device__ void local_array(__half2* local_buffer,
                            float* __restrict__ scales,
                            float* __restrict__ offsets,
                            int8_t* __restrict__ output_data,
                            const int& elems_per_group,
                            const int& groups);
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half* data, Params<qType, numBits> q_params)
{
    constexpr int32_t elems = 16 / sizeof(__half);
    constexpr int32_t num_elems_packed = 8 / numBits;

#pragma unroll
    for (int i = 0, oi = 0; i < elems; i += num_elems_packed, oi++) {
        if (num_elems_packed == 1) {
            // TODO(cmikeh2): refactor to use conversion utils
            local_output[i] = q_params.quantize(data[i]);
        } else if (num_elems_packed == 2) {
            int8_t data_i8_1 = q_params.quantize(data[i]);
            int8_t data_i8_2 = q_params.quantize(data[i + 1]);
            auto data_i8 = PackedInt4{data_i8_2, data_i8_1};
            local_output[oi] = *((int8_t*)(&data_i8));
        }
    }
}
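
/*
Packing note (illustrative, based on the loop above): with numBits = 4,
num_elems_packed = 2, so each pair of adjacent __half inputs collapses into a single
output byte via PackedInt4; the 8 halves of one 16-byte chunk therefore produce 4
bytes of output. With numBits = 8 the chunk produces 8 output bytes, one per input.
*/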
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half2* data, Params<qType, numBits> q_params)
{
    const __half* data_cast = reinterpret_cast<const __half*>(data);
    _chunk<numBits>(local_output, data_cast, q_params);
}

template <Type qType, int numChunks>
DS_D_INLINE GroupStats<qType> _local_serial_reduce(__half2* local_buffer)
{
    GroupStats<qType> stats;
#pragma unroll
    for (int i = 0; i < numChunks * h2_per_load; i++) { stats.update(local_buffer[i]); }
    return stats;
}
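
/*
Sizing note (illustrative): h2_per_load = 16 / sizeof(__half2) = 4, so for
numChunks = 4 this loop visits 16 __half2 values (32 halves), i.e. one thread's
entire 64-byte register tile, before the cross-thread reduction in get_params().
*/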
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
DS_D_INLINE void local_array(cg::thread_block& tb,
                             cg::thread_block_tile<hw_warp_size>& warp,
                             __half2* local_buffer,
                             float* __restrict__ global_params,
                             int8_t* __restrict__ output_data,
                             const int& elems_per_group,
                             const int& groups,
                             Params<qType, numBits> q_params)
{
    constexpr int num_ele_int8 = 8 / numBits;
    constexpr int num_int8_out = quantize::h_per_load / num_ele_int8;

    // Indexing offsets
    const int block_num =
        (tb.group_index().x * max_threads / threads_per_group) + tb.thread_index().y;
    const int block_offset = block_num * elems_per_group;
    const int elem_offset = tb.thread_index().x * quantize::h_per_load;
    const int base_offset = (block_offset + elem_offset) / num_ele_int8;
    const int stride = tb.size() * quantize::h_per_load / num_ele_int8;

    int8_t local_output[num_int8_out];

    if (tb.thread_index().x == 0 && block_num < groups) {
        q_params.store(
            global_params,
            (tb.group_index().x * max_threads / threads_per_group) + tb.thread_index().y);
    }

#pragma unroll
    for (int i = 0; i < numChunks; i++) {
        if (elem_offset + i * stride * num_ele_int8 < elems_per_group && block_num < groups) {
            quantize::_chunk<numBits, qType>(
                local_output, local_buffer + i * quantize::h2_per_load, q_params);
            mem_access::store_global<num_int8_out>(output_data + (base_offset + i * stride),
                                                   local_output);
        }
    }
}
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
DS_D_INLINE void local_array(cg::thread_block& tb,
                             cg::thread_block_tile<hw_warp_size>& warp,
                             __half* local_buffer,
                             float* __restrict__ global_params,
                             int8_t* __restrict__ output_data,
                             const int& elems_per_group,
                             const int& groups,
                             Params<qType, numBits> q_params)
{
    __half2* local_buffer_h2 = reinterpret_cast<__half2*>(local_buffer);
    quantize::local_array<qType, numBits, numChunks, threads_per_group, max_threads>(
        tb, warp, local_buffer_h2, global_params, output_data, elems_per_group, groups, q_params);
}
template <Type qType,
          int numBits,
          int numChunks,
          int threads_per_group = max_threads,
          int max_threads = 256>
__device__ void local_array(__half2* local_buffer,
                            float* __restrict__ global_params,
                            int8_t* __restrict__ output_data,
                            const int& elems_per_group,
                            const int& groups)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    auto group_stats = _local_serial_reduce<qType, numChunks>(local_buffer);
    auto params = group_stats.template get_params<numBits, threads_per_group>(tb, warp);

    quantize::local_array<qType, numBits, numChunks, threads_per_group, max_threads>(
        tb, warp, local_buffer, global_params, output_data, elems_per_group, groups, params);
}
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
__device__ void local_array(__half* local_buffer,
                            float* __restrict__ global_params,
                            int8_t* __restrict__ output_data,
                            const int& elems_per_group,
                            const int& groups)
{
    __half2* local_buffer_h2 = reinterpret_cast<__half2*>(local_buffer);
    quantize::local_array<qType, numBits, numChunks, threads_per_group, max_threads>(
        local_buffer_h2, global_params, output_data, elems_per_group, groups);
}

}  // namespace quantize
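
/*
Usage sketch (illustrative only; the kernel below is hypothetical and not part of this
header). A quantization kernel typically stages one tile of __half2 data per thread in
registers, then hands it to quantize::local_array, which reduces the group statistics,
writes the scale (and offset, for asymmetric quantization) to the params buffer, and
stores the packed int8 output:

    template <int numBits, int numChunks, int threads_per_group, int max_threads>
    __global__ void example_quant_kernel(int8_t* output,
                                         float* params,
                                         const __half2* input,
                                         int elems_per_group,
                                         int groups)
    {
        __half2 local_buffer[numChunks * quantize::h2_per_load];
        // ... cooperative vectorized loads of `input` into local_buffer elided ...
        quantize::local_array<quantize::Type::Symmetric, numBits, numChunks,
                              threads_per_group, max_threads>(
            local_buffer, params, output, elems_per_group, groups);
    }
*/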