// quantizer.cu

#include <math.h>

#include <cstdint>
#include <utility>

#include <cooperative_groups.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>

#include "custom_cuda_layers.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

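// Symmetric (signed) quantize-dequantize, one group per thread block: each
// block reduces the absolute maximum of its group, derives
//   q_scale = 2^num_bits / (2 * max),
// rounds x * q_scale to the nearest integer, clamps it to
// [-2^(num_bits-1), 2^(num_bits-1) - 1], and stores q / q_scale back in place.
// Worked example (num_bits = 8, max = 2.0): q_scale ~= 64, so x = 0.7 maps to
// round(0.7 * 64) = 45 and dequantizes to 45 / 64 ~= 0.7031.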
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    cg::thread_block b = cg::this_thread_block();  // thread block
    cg::thread_block_tile<32> g =
        cg::tiled_partition<32>(b);  // warp tile; 32 lanes is not optimal on AMD, where the wavefront size is 64

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    constexpr int granularity = 16;
    constexpr int vals_per_access = granularity / sizeof(__half);

    __half data[vals_per_access];

    int group_id = blockIdx.x;
    int thread_index = id * vals_per_access;
    int reg_count = 0;
    int offset = group_id * group_size;
    float max = -10000.0;

    for (int thread_index = id * vals_per_access; thread_index < group_size;
         thread_index += blockDim.x * vals_per_access) {
        mem_access::load_global<granularity>(data, vals + offset + thread_index);

#pragma unroll
        for (int i = 0; i < vals_per_access; i++) {
            if (abs((float)data[i]) > max) max = abs((float)data[i]);
        }
    }

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

    __shared__ float partialMax[WARP_SIZE];

    if (lane == 0) partialMax[gid] = max;

    b.sync();

    if (lane < warp_num) max = partialMax[lane];

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }

    max = g.shfl(max, 0);

    float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5);
    float q_scale_inv = 1 / q_scale;

    int q_range_max = (1 << (num_bits - 1)) - 1;
    int q_range_min = -(1 << (num_bits - 1));

    for (int thread_index = id * vals_per_access; thread_index < group_size;
         thread_index += blockDim.x * vals_per_access) {
        mem_access::load_global<granularity>(data, vals + offset + thread_index);

#pragma unroll
        for (int j = 0; j < vals_per_access; j++) {
            float q_data;
            q_data = __half2float(data[j]);
            q_data = __float2int_rn(q_data * q_scale);
            q_data = q_data > (q_range_max) ? (q_range_max)
                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
            data[j] = __float2half_rn(q_data * q_scale_inv);
        }

        mem_access::store_global<granularity>(vals + offset + thread_index, data);
    }
#endif
}

__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    constexpr int granularity = 16;
    constexpr int vals_per_access = granularity / sizeof(float);

    float data[vals_per_access];

    int bid = blockIdx.x;
    int thread_index = id * vals_per_access;
    int reg_count = 0;
    int offset = bid * group_size;
    float max = -10000.0;

    for (int thread_index = id * vals_per_access; thread_index < group_size;
         thread_index += blockDim.x * vals_per_access) {
        mem_access::load_global<granularity>(data, vals + offset + thread_index);

#pragma unroll
        for (int i = 0; i < vals_per_access; i++) {
            if (abs(data[i]) > max) max = abs(data[i]);
        }
    }

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

    __shared__ float partialMax[WARP_SIZE];

    if (lane == 0) partialMax[gid] = max;

    b.sync();

    if (lane < warp_num) max = partialMax[lane];

    b.sync();

#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }

    max = g.shfl(max, 0);

    float q_scale = (1 << num_bits) / (2 * max + 1e-5);
    float q_scale_inv = 1 / q_scale;

    int q_range_max = (1 << (num_bits - 1)) - 1;
    int q_range_min = -(1 << (num_bits - 1));

    for (int thread_index = id * vals_per_access; thread_index < group_size;
         thread_index += blockDim.x * vals_per_access) {
        mem_access::load_global<granularity>(data, vals + offset + thread_index);

#pragma unroll
        for (int j = 0; j < vals_per_access; j++) {
            float q_data;
            q_data = __float2int_rn(data[j] * q_scale);
            q_data = q_data > (q_range_max) ? (q_range_max)
                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
            data[j] = q_data * q_scale_inv;
        }

        mem_access::store_global<granularity>(vals + offset + thread_index, data);
    }
}

template <typename T>
void launch_quantize_kernel(T* vals,
                            int total_count,
                            int group_num,
                            int num_bits,
                            cudaStream_t stream)
{
    dim3 grid_dim(group_num);
    dim3 block_dim(1024);

    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(vals, total_count / group_num, num_bits);
}

template void launch_quantize_kernel(float* vals,
                                     int total_count,
                                     int group_num,
                                     int num_bits,
                                     cudaStream_t stream);
template void launch_quantize_kernel(__half* vals,
                                     int total_count,
                                     int group_num,
                                     int num_bits,
                                     cudaStream_t stream);

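// Example usage (a sketch; `d_activations`, `elems`, and `groups` are hypothetical
// caller-side names, and `elems` is assumed to be divisible by `groups`):
//
//   launch_quantize_kernel(d_activations, elems, groups, 8, stream);
//
// Each of the `groups` blocks quantizes a contiguous chunk of elems / groups values.
//
// The sr_quantize_* kernels below use stochastic rounding instead of
// round-to-nearest: each value is first truncated toward zero, then pushed one
// level away from zero with probability equal to the (scaled) truncation error,
// so the rounding is unbiased in expectation except at the clipping boundaries.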
__global__ void sr_quantize_kernel(__half* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    __half2 data_low[128];
    __half2 data_high[128];

    int bid = blockIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;
    int offset = bid * token_size;
    int group_index = bid * token_size + tid;
    int total_count = token_size * token_num;

    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];

            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);

            if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
            if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
            if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
            if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);

            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        max = g.shfl(max, 0);

        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));

        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);

                float2 q_data_int[2];
                q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
                q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
                q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
                q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));

                // Stochastic rounding
                float4 rand = curand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
                q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;

                q_data_int[0].x =
                    (rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
                        ? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
                        : q_data_int[0].x;
                q_data_int[0].y =
                    (rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
                        ? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
                        : q_data_int[0].y;
                q_data_int[1].x =
                    (rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
                        ? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
                        : q_data_int[1].x;
                q_data_int[1].y =
                    (rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
                        ? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
                        : q_data_int[1].y;

                data_f[0].x = q_data_int[0].x / q_scale_val;
                data_f[0].y = q_data_int[0].y / q_scale_val;
                data_f[1].x = q_data_int[1].x / q_scale_val;
                data_f[1].y = q_data_int[1].y / q_scale_val;

                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);

                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}

__global__ void sr_quantize_kernel(float* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    int idx = blockIdx.x * blockDim.x + id;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[128];

    int bid = blockIdx.x;
    int tid = threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;
    int total_count = token_size * token_num;

    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            data[reg_count] = vals_cast[group_index];

            if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
            if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
            if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
            if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);

            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        max = g.shfl(max, 0);

        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));

        int offset = bid * token_size;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];

                float4 q_data_int;
                q_data_int.x = (float)((int)(q_data.x * q_scale_val));
                q_data_int.y = (float)((int)(q_data.y * q_scale_val));
                q_data_int.w = (float)((int)(q_data.w * q_scale_val));
                q_data_int.z = (float)((int)(q_data.z * q_scale_val));

                // Stochastic rounding
                float4 rand = curand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
                q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;

                q_data_int.x =
                    (rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
                        ? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
                        : q_data_int.x;
                q_data_int.y =
                    (rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
                        ? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
                        : q_data_int.y;
                q_data_int.w =
                    (rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
                        ? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
                        : q_data_int.w;
                q_data_int.z =
                    (rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
                        ? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
                        : q_data_int.z;

                q_data_int.x /= q_scale_val;
                q_data_int.y /= q_scale_val;
                q_data_int.w /= q_scale_val;
                q_data_int.z /= q_scale_val;

                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}

template <typename T>
void launch_sr_quantize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               cudaStream_t stream)
{
    dim3 block_dim(1024);
    dim3 grid_dim(group_num);

    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
        vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}

template void launch_sr_quantize_kernel(float* vals,
                                        int total_count,
                                        int group_num,
                                        int num_bits,
                                        cudaStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
                                        int total_count,
                                        int group_num,
                                        int num_bits,
                                        cudaStream_t stream);

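// The *_asym kernels below use asymmetric (min/max) quantization: each group is
// mapped onto the unsigned range [0, 2^num_bits - 1] relative to the group
// minimum, with
//   q_scale = (max - min) / 2^num_bits  (plus a small epsilon),
//   q       = round((x - min) / q_scale),
// and the value is dequantized in place as q * q_scale + min.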
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);
    float2 data[MAX_REG];

    int group_id = blockIdx.x;

    {
        int group_index = id;
        int reg_count = 0;
        int offset = group_id * group_size;
        float max = -10000.0;
        float min = 10000.0;

        while (group_index < group_size && reg_count < MAX_REG) {
            data[reg_count] = vals_cast[offset + group_index];
            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);

            if (((float)data_h[0]) > max) max = (float)data_h[0];
            if (((float)data_h[1]) > max) max = (float)data_h[1];
            if (((float)data_h[2]) > max) max = (float)data_h[2];
            if (((float)data_h[3]) > max) max = (float)data_h[3];

            if (((float)data_h[0]) < min) min = (float)data_h[0];
            if (((float)data_h[1]) < min) min = (float)data_h[1];
            if (((float)data_h[2]) < min) min = (float)data_h[2];
            if (((float)data_h[3]) < min) min = (float)data_h[3];

            group_index += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float q_scale_inv = 1 / q_scale;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + id;
            if (group_index < group_size) {
                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
                float2 q_data[2];
                q_data[0] = __half22float2(data_h[0]);
                q_data[1] = __half22float2(data_h[1]);

                float2 q_data_int[2];
                q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
                q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
                q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
                q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);

                q_data_int[0].x = q_data_int[0].x * q_scale + min;
                q_data_int[0].y = q_data_int[0].y * q_scale + min;
                q_data_int[1].x = q_data_int[1].x * q_scale + min;
                q_data_int[1].y = q_data_int[1].y * q_scale + min;

                data_h[0] = __float22half2_rn(q_data_int[0]);
                data_h[1] = __float22half2_rn(q_data_int[1]);

                vals_cast[offset + group_index] = data[i];
            }
        }
    }
#endif
}

__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float4* vals_cast = reinterpret_cast<float4*>(vals);
    float4 data[MAX_REG];

    int bid = blockIdx.x;
    int group_index = bid * group_size + id;
    int reg_count = 0;
    float max = -10000.0;
    float min = 10000.0;

    while (id < group_size && reg_count < MAX_REG) {
        float4 data_reg = vals_cast[group_index];
        data[reg_count] = data_reg;

        if (data_reg.x > max) max = data_reg.x;
        if (data_reg.y > max) max = data_reg.y;
        if (data_reg.w > max) max = data_reg.w;
        if (data_reg.z > max) max = data_reg.z;

        if (data_reg.x < min) min = data_reg.x;
        if (data_reg.y < min) min = data_reg.y;
        if (data_reg.w < min) min = data_reg.w;
        if (data_reg.z < min) min = data_reg.z;

        group_index += blockDim.x;
        id += blockDim.x;
        reg_count++;
    }

    id = threadIdx.x;

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(min, i);
        if (min > temp) min = temp;
    }

    __shared__ float partialMax[WARP_SIZE];
    __shared__ float partialMin[WARP_SIZE];

    if (lane == 0) partialMax[gid] = max;
    if (lane == 0) partialMin[gid] = min;

    b.sync();

    if (lane < warp_num) max = partialMax[lane];
    if (lane < warp_num) min = partialMin[lane];

#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }
#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(min, i);
        if (min > temp) min = temp;
    }

    max = g.shfl(max, 0);
    min = g.shfl(min, 0);

    float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
    float q_scale_inv = 1 / q_scale;

    for (int i = 0; i < reg_count; i++) {
        group_index = i * blockDim.x + id;
        if (group_index < group_size) {
            float4 q_data;
            q_data = data[i];

            float4 q_data_int;
            q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
            q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
            q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
            q_data_int.z = roundf((q_data.z - min) * q_scale_inv);

            q_data.x = q_data_int.x * q_scale + min;
            q_data.y = q_data_int.y * q_scale + min;
            q_data.w = q_data_int.w * q_scale + min;
            q_data.z = q_data_int.z * q_scale + min;

            vals_cast[group_index + bid * group_size] = q_data;
        }
    }
}

template <typename T>
void launch_quantize_kernel_asym(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream)
{
    dim3 grid_dim(group_num);
    dim3 block_dim(1024);

    quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
        vals, (total_count / group_num) / 4, num_bits);
}

template void launch_quantize_kernel_asym(float* vals,
                                          int total_count,
                                          int group_num,
                                          int num_bits,
                                          cudaStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
                                          int total_count,
                                          int group_num,
                                          int num_bits,
                                          cudaStream_t stream);

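// Example usage (a sketch; `d_buf`, `elems`, and `groups` are hypothetical, and
// elems / groups is assumed to be a multiple of 4, since the asymmetric kernels
// above read the data through float2 / float4 vector loads):
//
//   launch_quantize_kernel_asym(d_buf, elems, groups, 8, stream);
//
// The sr_quantize_kernel_asym variants below combine the min/max scheme with the
// stochastic rounding used earlier: (x - min) / q_scale is truncated (rounded
// down, since it is non-negative) and bumped up one level with probability equal
// to the scaled truncation error, as long as the level stays below 2^num_bits - 1.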
__global__ void sr_quantize_kernel_asym(__half* vals,
                                        int token_size,
                                        int token_num,
                                        int num_bits,
                                        std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    __half2 data_low[128];
    __half2 data_high[128];

    int bid = blockIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;
    int offset = bid * token_size;
    int group_index = bid * token_size + tid;
    int total_count = token_size * token_num;

    if (group_index < total_count) {
        float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];

            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);

            if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
            if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
            if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
            if (((float)data_f[1].y) > max) max = (float)data_f[1].y;

            if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
            if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
            if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
            if (((float)data_f[1].y) < min) min = (float)data_f[1].y;

            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float q_scale_val_inv = 1 / q_scale_val;
        float high_q = (float)((1 << num_bits) - 1);

        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);

                float2 q_data_int[2];
                q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
                q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
                q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
                q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));

                // Stochastic rounding
                float4 rand = curand_uniform4(&state);

                float q_error[4];
                q_error[0] =
                    abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
                q_error[1] =
                    abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
                q_error[2] =
                    abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
                q_error[3] =
                    abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;

                q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
                                      ? (q_data_int[0].x + 1)
                                      : q_data_int[0].x;
                q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
                                      ? (q_data_int[0].y + 1)
                                      : q_data_int[0].y;
                q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
                                      ? (q_data_int[1].x + 1)
                                      : q_data_int[1].x;
                q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
                                      ? (q_data_int[1].y + 1)
                                      : q_data_int[1].y;

                data_f[0].x = q_data_int[0].x * q_scale_val + min;
                data_f[0].y = q_data_int[0].y * q_scale_val + min;
                data_f[1].x = q_data_int[1].x * q_scale_val + min;
                data_f[1].y = q_data_int[1].y * q_scale_val + min;

                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);

                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}

__global__ void sr_quantize_kernel_asym(float* vals,
                                        int token_size,
                                        int token_num,
                                        int num_bits,
                                        std::pair<uint64_t, uint64_t> seed)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    int idx = blockIdx.x * blockDim.x + id;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[128];

    int bid = blockIdx.x;
    int tid = threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;
    int total_count = token_size * token_num;

    if (group_index < total_count) {
        float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            float4 data_reg = vals_cast[group_index];
            data[reg_count] = data_reg;

            if (data_reg.x > max) max = data_reg.x;
            if (data_reg.y > max) max = data_reg.y;
            if (data_reg.w > max) max = data_reg.w;
            if (data_reg.z > max) max = data_reg.z;

            if (data_reg.x < min) min = data_reg.x;
            if (data_reg.y < min) min = data_reg.y;
            if (data_reg.w < min) min = data_reg.w;
            if (data_reg.z < min) min = data_reg.z;

            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float high_q = (float)((1 << num_bits) - 1);

        int offset = bid * token_size;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];

                float4 q_data_int;
                q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
                q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
                q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
                q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));

                // Stochastic rounding
                float4 rand = curand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
                q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
                q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
                q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;

                q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
                                                                              : q_data_int.x;
                q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
                                                                              : q_data_int.y;
                q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
                                                                              : q_data_int.w;
                q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
                                                                              : q_data_int.z;

                q_data_int.x = q_data_int.x * q_scale_val + min;
                q_data_int.y = q_data_int.y * q_scale_val + min;
                q_data_int.w = q_data_int.w * q_scale_val + min;
                q_data_int.z = q_data_int.z * q_scale_val + min;

                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}

template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream)
{
    dim3 block_dim(1024);
    dim3 grid_dim(group_num);

    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    sr_quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
        vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}

template void launch_sr_quantize_kernel_asym(float* vals,
                                             int total_count,
                                             int group_num,
                                             int num_bits,
                                             cudaStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
                                             int total_count,
                                             int group_num,
                                             int num_bits,
                                             cudaStream_t stream);
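
// Example usage of the stochastic-rounding launchers (a sketch; `d_grads`,
// `elems`, and `groups` are hypothetical names, and elems / groups is assumed to
// be a multiple of 4 because the launchers pass (total_count / group_num) / 4 as
// the per-group vector count):
//
//   launch_sr_quantize_kernel(d_grads, elems, groups, 8, stream);
//   launch_sr_quantize_kernel_asym(d_grads, elems, groups, 8, stream);
//
// Both launchers obtain their Philox seed and offset from
// Context::Instance().IncrementOffset(), so the random bits used for rounding
// advance across successive calls.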