// quantizer.cu

#include <math.h>

// The headers below may already be pulled in indirectly through custom_cuda_layers.h;
// they are listed here so the file's direct dependencies (half intrinsics, cuRAND,
// cooperative groups, std::pair/uint64_t) are explicit.
#include <cooperative_groups.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>

#include <cstdint>
#include <utility>

#include "custom_cuda_layers.h"

namespace cg = cooperative_groups;
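
// Symmetric per-group quantize-dequantize for FP16 data. Each block handles one
// group: threads strided-load the group into registers as float2 (4 halves per
// load), reduce the absolute maximum first within each warp via shfl_xor and then
// across warps through the partialMax shared buffer, and finally round each value
// onto the grid defined by q_scale = 2^num_bits / (2 * max) before writing the
// dequantized result back in place.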
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    float2 data[MAX_REG];

    int group_id = blockIdx.x;

    {
        int group_index = id;
        int reg_count = 0;
        int offset = group_id * group_size;
        float max = -10000.0;

        while (group_index < group_size && reg_count < MAX_REG) {
            data[reg_count] = vals_cast[offset + group_index];
            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);

            if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
            if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
            if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
            if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);

            group_index += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        max = g.shfl(max, 0);

        float q_scale = (1 << num_bits) / (2 * max + 1e-5);
        float q_scale_inv = 1 / q_scale;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + id;
            if (group_index < group_size) {
                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
                float2 q_data[2];
                q_data[0] = __half22float2(data_h[0]);
                q_data[1] = __half22float2(data_h[1]);

                float2 q_data_int[2];
                q_data_int[0].x = roundf(q_data[0].x * q_scale);
                q_data_int[0].y = roundf(q_data[0].y * q_scale);
                q_data_int[1].x = roundf(q_data[1].x * q_scale);
                q_data_int[1].y = roundf(q_data[1].y * q_scale);

                q_data_int[0].x *= q_scale_inv;
                q_data_int[0].y *= q_scale_inv;
                q_data_int[1].x *= q_scale_inv;
                q_data_int[1].y *= q_scale_inv;

                data_h[0] = __float22half2_rn(q_data_int[0]);
                data_h[1] = __float22half2_rn(q_data_int[1]);

                vals_cast[offset + group_index] = data[i];
            }
        }
    }
#endif
}
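
// FP32 variant of the symmetric kernel above. Identical two-stage max reduction,
// but values are loaded four at a time as float4 instead of packed halves.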
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[MAX_REG];

    int bid = blockIdx.x;

    int group_index = bid * group_size + id;
    int reg_count = 0;

    float max = -10000.0;

    while (id < group_size && reg_count < MAX_REG) {
        float4 data_reg = vals_cast[group_index];
        data[reg_count] = data_reg;

        if (abs(data_reg.x) > max) max = abs(data_reg.x);
        if (abs(data_reg.y) > max) max = abs(data_reg.y);
        if (abs(data_reg.z) > max) max = abs(data_reg.z);
        if (abs(data_reg.w) > max) max = abs(data_reg.w);

        group_index += blockDim.x;
        id += blockDim.x;
        reg_count++;
    }

    id = threadIdx.x;

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

    __shared__ float partialMax[WARP_SIZE];

    if (lane == 0) partialMax[gid] = max;

    b.sync();

    if (lane < warp_num) max = partialMax[lane];

    b.sync();

#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }

    max = g.shfl(max, 0);

    float q_scale = (1 << num_bits) / (2 * max + 1e-5);
    float q_scale_inv = 1 / q_scale;

    for (int i = 0; i < reg_count; i++) {
        group_index = i * blockDim.x + id;
        if (group_index < group_size) {
            float4 q_data;
            q_data = data[i];

            float4 q_data_int;
            q_data_int.x = roundf(q_data.x * q_scale);
            q_data_int.y = roundf(q_data.y * q_scale);
            q_data_int.w = roundf(q_data.w * q_scale);
            q_data_int.z = roundf(q_data.z * q_scale);

            q_data.x = q_data_int.x * q_scale_inv;
            q_data.y = q_data_int.y * q_scale_inv;
            q_data.w = q_data_int.w * q_scale_inv;
            q_data.z = q_data_int.z * q_scale_inv;

            vals_cast[group_index + bid * group_size] = q_data;
        }
    }
}
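
// Host-side launcher: one 1024-thread block per quantization group. The kernels
// consume 4 elements per vectorized load, so group_size is passed as
// (total_count / group_num) / 4; total_count is expected to be divisible by
// 4 * group_num.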
template <typename T>
void launch_quantize_kernel(T* vals,
                            int total_count,
                            int group_num,
                            int num_bits,
                            cudaStream_t stream)
{
    dim3 grid_dim(group_num);
    dim3 block_dim(1024);

    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
        vals, (total_count / group_num) / 4, num_bits);
}

template void launch_quantize_kernel(float* vals,
                                     int total_count,
                                     int group_num,
                                     int num_bits,
                                     cudaStream_t stream);
template void launch_quantize_kernel(__half* vals,
                                     int total_count,
                                     int group_num,
                                     int num_bits,
                                     cudaStream_t stream);
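
// Symmetric quantization with stochastic rounding for FP16 data. Each block owns
// one token: after the same block-wide absmax reduction, values are truncated
// toward zero onto the quantization grid and then pushed one level away from zero
// with probability equal to the normalized truncation error, using a per-thread
// Philox counter-based RNG initialized from the (seed, offset) pair supplied by
// the launcher.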
__global__ void sr_quantize_kernel(__half* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    __half2 data_low[128];
    __half2 data_high[128];

    int bid = blockIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;
    int offset = bid * token_size;
    int group_index = bid * token_size + tid;

    int total_count = token_size * token_num;
    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;
        while (tid < token_size) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];

            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);

            if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
            if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
            if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
            if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);

            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        max = g.shfl(max, 0);

        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));

        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);

                float2 q_data_int[2];
                q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
                q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
                q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
                q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));

                // Stochastic rounding
                float4 rand = curand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
                q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;

                q_data_int[0].x =
                    (rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
                        ? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
                        : q_data_int[0].x;
                q_data_int[0].y =
                    (rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
                        ? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
                        : q_data_int[0].y;
                q_data_int[1].x =
                    (rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
                        ? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
                        : q_data_int[1].x;
                q_data_int[1].y =
                    (rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
                        ? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
                        : q_data_int[1].y;

                data_f[0].x = q_data_int[0].x / q_scale_val;
                data_f[0].y = q_data_int[0].y / q_scale_val;
                data_f[1].x = q_data_int[1].x / q_scale_val;
                data_f[1].y = q_data_int[1].y / q_scale_val;

                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);

                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}
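
// FP32 variant of the stochastic-rounding kernel above, operating on float4 loads.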
__global__ void sr_quantize_kernel(float* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    int idx = blockIdx.x * blockDim.x + id;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[128];

    int bid = blockIdx.x;
    int tid = threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;
    int total_count = token_size * token_num;
    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            data[reg_count] = vals_cast[group_index];

            if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
            if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
            if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
            if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);

            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        max = g.shfl(max, 0);

        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));

        int offset = bid * token_size;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];

                float4 q_data_int;
                q_data_int.x = (float)((int)(q_data.x * q_scale_val));
                q_data_int.y = (float)((int)(q_data.y * q_scale_val));
                q_data_int.w = (float)((int)(q_data.w * q_scale_val));
                q_data_int.z = (float)((int)(q_data.z * q_scale_val));

                // Stochastic rounding
                float4 rand = curand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
                q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;

                q_data_int.x =
                    (rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
                        ? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
                        : q_data_int.x;
                q_data_int.y =
                    (rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
                        ? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
                        : q_data_int.y;
                q_data_int.w =
                    (rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
                        ? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
                        : q_data_int.w;
                q_data_int.z =
                    (rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
                        ? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
                        : q_data_int.z;

                q_data_int.x /= q_scale_val;
                q_data_int.y /= q_scale_val;
                q_data_int.w /= q_scale_val;
                q_data_int.z /= q_scale_val;

                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}
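
// Launcher for the stochastic-rounding kernels. Besides the usual grid setup it
// advances the shared Philox offset (Context::Instance().IncrementOffset) in
// proportion to the per-thread workload, so successive launches draw fresh
// random numbers.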
template <typename T>
void launch_sr_quantize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               cudaStream_t stream)
{
    dim3 block_dim(1024);
    dim3 grid_dim(group_num);

    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
        vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}

template void launch_sr_quantize_kernel(float* vals,
                                        int total_count,
                                        int group_num,
                                        int num_bits,
                                        cudaStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
                                        int total_count,
                                        int group_num,
                                        int num_bits,
                                        cudaStream_t stream);
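
// Asymmetric per-group quantize-dequantize for FP16 data. Tracks both the group
// minimum and maximum (reduced the same way as the symmetric kernels, with a
// partialMin buffer alongside partialMax) and maps values onto 2^num_bits levels
// over [min, max] via q_scale = (max - min) / 2^num_bits, dequantizing back with
// the same scale and min offset before the in-place store.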
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    float2 data[MAX_REG];

    int group_id = blockIdx.x;

    {
        int group_index = id;
        int reg_count = 0;
        int offset = group_id * group_size;
        float max = -10000.0;
        float min = 10000.0;

        while (group_index < group_size && reg_count < MAX_REG) {
            data[reg_count] = vals_cast[offset + group_index];
            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);

            if (((float)data_h[0]) > max) max = (float)data_h[0];
            if (((float)data_h[1]) > max) max = (float)data_h[1];
            if (((float)data_h[2]) > max) max = (float)data_h[2];
            if (((float)data_h[3]) > max) max = (float)data_h[3];

            if (((float)data_h[0]) < min) min = (float)data_h[0];
            if (((float)data_h[1]) < min) min = (float)data_h[1];
            if (((float)data_h[2]) < min) min = (float)data_h[2];
            if (((float)data_h[3]) < min) min = (float)data_h[3];

            group_index += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float q_scale_inv = 1 / q_scale;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + id;
            if (group_index < group_size) {
                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
                float2 q_data[2];
                q_data[0] = __half22float2(data_h[0]);
                q_data[1] = __half22float2(data_h[1]);

                float2 q_data_int[2];
                q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
                q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
                q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
                q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);

                q_data_int[0].x = q_data_int[0].x * q_scale + min;
                q_data_int[0].y = q_data_int[0].y * q_scale + min;
                q_data_int[1].x = q_data_int[1].x * q_scale + min;
                q_data_int[1].y = q_data_int[1].y * q_scale + min;

                data_h[0] = __float22half2_rn(q_data_int[0]);
                data_h[1] = __float22half2_rn(q_data_int[1]);

                vals_cast[offset + group_index] = data[i];
            }
        }
    }
#endif
}
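
// FP32 variant of the asymmetric kernel above, using float4 loads.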
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[MAX_REG];

    int bid = blockIdx.x;

    int group_index = bid * group_size + id;
    int reg_count = 0;

    float max = -10000.0;
    float min = 10000.0;

    while (id < group_size && reg_count < MAX_REG) {
        float4 data_reg = vals_cast[group_index];
        data[reg_count] = data_reg;

        if (data_reg.x > max) max = data_reg.x;
        if (data_reg.y > max) max = data_reg.y;
        if (data_reg.w > max) max = data_reg.w;
        if (data_reg.z > max) max = data_reg.z;

        if (data_reg.x < min) min = data_reg.x;
        if (data_reg.y < min) min = data_reg.y;
        if (data_reg.w < min) min = data_reg.w;
        if (data_reg.z < min) min = data_reg.z;

        group_index += blockDim.x;
        id += blockDim.x;
        reg_count++;
    }

    id = threadIdx.x;

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(min, i);
        if (min > temp) min = temp;
    }

    __shared__ float partialMax[WARP_SIZE];
    __shared__ float partialMin[WARP_SIZE];

    if (lane == 0) partialMax[gid] = max;
    if (lane == 0) partialMin[gid] = min;

    b.sync();

    if (lane < warp_num) max = partialMax[lane];
    if (lane < warp_num) min = partialMin[lane];

#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }

#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(min, i);
        if (min > temp) min = temp;
    }

    max = g.shfl(max, 0);
    min = g.shfl(min, 0);

    float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
    float q_scale_inv = 1 / q_scale;

    for (int i = 0; i < reg_count; i++) {
        group_index = i * blockDim.x + id;
        if (group_index < group_size) {
            float4 q_data;
            q_data = data[i];

            float4 q_data_int;
            q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
            q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
            q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
            q_data_int.z = roundf((q_data.z - min) * q_scale_inv);

            q_data.x = q_data_int.x * q_scale + min;
            q_data.y = q_data_int.y * q_scale + min;
            q_data.w = q_data_int.w * q_scale + min;
            q_data.z = q_data_int.z * q_scale + min;

            vals_cast[group_index + bid * group_size] = q_data;
        }
    }
}
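
// Launcher for the asymmetric kernels; same grid/block layout and element
// packing as launch_quantize_kernel.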
template <typename T>
void launch_quantize_kernel_asym(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream)
{
    dim3 grid_dim(group_num);
    dim3 block_dim(1024);

    quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
        vals, (total_count / group_num) / 4, num_bits);
}

template void launch_quantize_kernel_asym(float* vals,
                                          int total_count,
                                          int group_num,
                                          int num_bits,
                                          cudaStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
                                          int total_count,
                                          int group_num,
                                          int num_bits,
                                          cudaStream_t stream);
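
// Asymmetric quantization with stochastic rounding for FP16 data. Levels span the
// unsigned range [0, 2^num_bits - 1] relative to the token minimum; values are
// truncated down to a level and then bumped up by one with probability equal to
// the normalized truncation error, using the same per-thread Philox stream.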
__global__ void sr_quantize_kernel_asym(__half* vals,
                                        int token_size,
                                        int token_num,
                                        int num_bits,
                                        std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    __half2 data_low[128];
    __half2 data_high[128];

    int bid = blockIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;
    int offset = bid * token_size;
    int group_index = bid * token_size + tid;

    int total_count = token_size * token_num;
    if (group_index < total_count) {
        float min = 10000.0;
        float max = -10000.0;
        while (tid < token_size) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];

            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);

            if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
            if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
            if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
            if (((float)data_f[1].y) > max) max = (float)data_f[1].y;

            if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
            if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
            if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
            if (((float)data_f[1].y) < min) min = (float)data_f[1].y;

            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float q_scale_val_inv = 1 / q_scale_val;
        float high_q = (float)((1 << num_bits) - 1);

        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);

                float2 q_data_int[2];
                q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
                q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
                q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
                q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));

                // Stochastic rounding
                float4 rand = curand_uniform4(&state);

                float q_error[4];
                q_error[0] =
                    abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
                q_error[1] =
                    abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
                q_error[2] =
                    abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
                q_error[3] =
                    abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;

                q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
                                      ? (q_data_int[0].x + 1)
                                      : q_data_int[0].x;
                q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
                                      ? (q_data_int[0].y + 1)
                                      : q_data_int[0].y;
                q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
                                      ? (q_data_int[1].x + 1)
                                      : q_data_int[1].x;
                q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
                                      ? (q_data_int[1].y + 1)
                                      : q_data_int[1].y;

                data_f[0].x = q_data_int[0].x * q_scale_val + min;
                data_f[0].y = q_data_int[0].y * q_scale_val + min;
                data_f[1].x = q_data_int[1].x * q_scale_val + min;
                data_f[1].y = q_data_int[1].y * q_scale_val + min;

                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);

                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}
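
// FP32 variant of the asymmetric stochastic-rounding kernel above.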
__global__ void sr_quantize_kernel_asym(float* vals,
                                        int token_size,
                                        int token_num,
                                        int num_bits,
                                        std::pair<uint64_t, uint64_t> seed)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    int idx = blockIdx.x * blockDim.x + id;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[128];

    int bid = blockIdx.x;
    int tid = threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;
    int total_count = token_size * token_num;
    if (group_index < total_count) {
        float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            float4 data_reg = vals_cast[group_index];
            data[reg_count] = data_reg;

            if (data_reg.x > max) max = data_reg.x;
            if (data_reg.y > max) max = data_reg.y;
            if (data_reg.w > max) max = data_reg.w;
            if (data_reg.z > max) max = data_reg.z;

            if (data_reg.x < min) min = data_reg.x;
            if (data_reg.y < min) min = data_reg.y;
            if (data_reg.w < min) min = data_reg.w;
            if (data_reg.z < min) min = data_reg.z;

            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float high_q = (float)((1 << num_bits) - 1);

        int offset = bid * token_size;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];

                float4 q_data_int;
                q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
                q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
                q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
                q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));

                // Stochastic rounding
                float4 rand = curand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
                q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
                q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
                q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;

                q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
                                                                              : q_data_int.x;
                q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
                                                                              : q_data_int.y;
                q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
                                                                              : q_data_int.w;
                q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
                                                                              : q_data_int.z;

                q_data_int.x = q_data_int.x * q_scale_val + min;
                q_data_int.y = q_data_int.y * q_scale_val + min;
                q_data_int.w = q_data_int.w * q_scale_val + min;
                q_data_int.z = q_data_int.z * q_scale_val + min;

                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}
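
// Launcher for the asymmetric stochastic-rounding kernels; mirrors
// launch_sr_quantize_kernel, including the RNG offset bookkeeping.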
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream)
{
    dim3 block_dim(1024);
    dim3 grid_dim(group_num);

    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    sr_quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
        vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}

template void launch_sr_quantize_kernel_asym(float* vals,
                                             int total_count,
                                             int group_num,
                                             int num_bits,
                                             cudaStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
                                             int total_count,
                                             int group_num,
                                             int num_bits,
                                             cudaStream_t stream);
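
/*
 * Usage sketch (illustrative only, not part of the API surface defined in this
 * file): one plausible host-side call sequence, assuming `d_vals` is a device
 * buffer of `total_count` floats that has already been populated and `stream`
 * is a valid cudaStream_t. Names such as d_vals, total_count, and group_num are
 * placeholders chosen for this example.
 *
 *   const int total_count = 8 * 4096;   // should be divisible by 4 * group_num
 *   const int group_num = 8;            // one CUDA block per group
 *   const int num_bits = 8;
 *
 *   launch_quantize_kernel(d_vals, total_count, group_num, num_bits, stream);
 *   // or, for min/max (asymmetric) quantization with stochastic rounding:
 *   // launch_sr_quantize_kernel_asym(d_vals, total_count, group_num, num_bits, stream);
 *
 * Each block then quantizes and dequantizes total_count / group_num elements in
 * place, four at a time through the vectorized float2/float4 loads above.
 */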