12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022 |
- #include <math.h>
- #include "custom_cuda_layers.h"
- #include "memory_access_utils.h"
- namespace cg = cooperative_groups;
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    // In-place symmetric quantize -> dequantize (round-to-nearest).
    // One block processes one group of `group_size` halves. Requires SM70+ (or HIP).
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g =
        cg::tiled_partition<32>(b);  // warp, 32 not optimal for AMD which should be 64.

    int gid = threadIdx.x >> 5;      // warp index within the block
    int lane = threadIdx.x & 0x1f;   // lane index within the warp
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    int id = threadIdx.x;

    constexpr int granularity = 16;  // bytes per vectorized global access
    constexpr int vals_per_access = granularity / sizeof(__half);
    __half data[vals_per_access];

    int group_id = blockIdx.x;
    int offset = group_id * group_size;

    // Pass 1: per-thread max of |x| over the group (block-stride loop).
    float max = -10000.f;
    for (int thread_index = id * vals_per_access; thread_index < group_size;
         thread_index += blockDim.x * vals_per_access) {
        mem_access::load_global<granularity>(data, vals + offset + thread_index);
#pragma unroll
        for (int i = 0; i < vals_per_access; i++) {
            float val = fabsf(__half2float(data[i]));
            if (val > max) max = val;
        }
    }

    // Intra-warp max reduction (butterfly shuffle).
#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

    // Cross-warp reduction through shared memory; iterate only over the
    // warp_num valid partials (matches the float overload of this kernel).
    __shared__ float partialMax[WARP_SIZE];
    if (lane == 0) partialMax[gid] = max;
    b.sync();
    if (lane < warp_num) max = partialMax[lane];
#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }
    max = g.shfl(max, 0);  // broadcast the block max to every lane

    // Symmetric scale: 2^num_bits levels span [-max, max]; epsilon avoids /0.
    float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5f);
    float q_scale_inv = 1 / q_scale;
    int q_range_max = (1 << (num_bits - 1)) - 1;
    int q_range_min = -(1 << (num_bits - 1));

    // Pass 2: quantize (round-to-nearest, clamp to range) and dequantize in place.
    for (int thread_index = id * vals_per_access; thread_index < group_size;
         thread_index += blockDim.x * vals_per_access) {
        mem_access::load_global<granularity>(data, vals + offset + thread_index);
#pragma unroll
        for (int j = 0; j < vals_per_access; j++) {
            float q_data = __half2float(data[j]);
            q_data = __float2int_rn(q_data * q_scale);
            q_data = q_data > (q_range_max) ? (q_range_max)
                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
            data[j] = __float2half_rn(q_data * q_scale_inv);
        }
        mem_access::store_global<granularity>(vals + offset + thread_index, data);
    }
#endif
}
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
    // In-place symmetric quantize -> dequantize (round-to-nearest).
    // One block processes one group of `group_size` floats.
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;      // warp index within the block
    int lane = threadIdx.x & 0x1f;   // lane index within the warp
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    int id = threadIdx.x;

    constexpr int granularity = 16;  // bytes per vectorized global access
    constexpr int vals_per_access = granularity / sizeof(float);
    float data[vals_per_access];

    int bid = blockIdx.x;
    int offset = bid * group_size;

    // Pass 1: per-thread max of |x| over the group (block-stride loop).
    float max = -10000.f;
    for (int thread_index = id * vals_per_access; thread_index < group_size;
         thread_index += blockDim.x * vals_per_access) {
        mem_access::load_global<granularity>(data, vals + offset + thread_index);
#pragma unroll
        for (int i = 0; i < vals_per_access; i++) {
            float val = fabsf(data[i]);
            if (val > max) max = val;
        }
    }

    // Intra-warp max reduction (butterfly shuffle).
#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

    // Cross-warp reduction through shared memory.
    __shared__ float partialMax[WARP_SIZE];
    if (lane == 0) partialMax[gid] = max;
    b.sync();
    if (lane < warp_num) max = partialMax[lane];
#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }
    max = g.shfl(max, 0);  // broadcast the block max to every lane

    // Symmetric scale: 2^num_bits levels span [-max, max]; epsilon avoids /0.
    float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5f);
    float q_scale_inv = 1 / q_scale;
    int q_range_max = (1 << (num_bits - 1)) - 1;
    int q_range_min = -(1 << (num_bits - 1));

    // Pass 2: quantize (round-to-nearest, clamp to range) and dequantize in place.
    // Bug fix: the dequantized value must NOT be rounded again — the previous
    // roundf(q_data * q_scale_inv) snapped every output to an integer,
    // destroying the quantization levels (the __half overload correctly
    // stores q_data * q_scale_inv directly).
    for (int thread_index = id * vals_per_access; thread_index < group_size;
         thread_index += blockDim.x * vals_per_access) {
        mem_access::load_global<granularity>(data, vals + offset + thread_index);
#pragma unroll
        for (int j = 0; j < vals_per_access; j++) {
            float q_data = __float2int_rn(data[j] * q_scale);
            q_data = q_data > (q_range_max) ? (q_range_max)
                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
            data[j] = q_data * q_scale_inv;
        }
        mem_access::store_global<granularity>(vals + offset + thread_index, data);
    }
}
// Host-side launcher for the symmetric quantizer.
// Grid: one block per group; block: 1024 threads. Each block reduces and
// rewrites `total_count / group_num` elements of its group in place.
template <typename T>
void launch_quantize_kernel(T* vals,
                            int total_count,
                            int group_num,
                            int num_bits,
                            cudaStream_t stream)
{
    const dim3 grid_dim(group_num);
    const dim3 block_dim(1024);
    const int group_size = total_count / group_num;
    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(vals, group_size, num_bits);
}
template void launch_quantize_kernel(float* vals,
                                     int total_count,
                                     int group_num,
                                     int num_bits,
                                     cudaStream_t stream);
template void launch_quantize_kernel(__half* vals,
                                     int total_count,
                                     int group_num,
                                     int num_bits,
                                     cudaStream_t stream);
__global__ void sr_quantize_kernel(__half* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    // In-place symmetric quantize -> dequantize with stochastic rounding.
    // One block handles one token; `token_size` is counted in float2 units
    // (4 halves per element). `seed` is (curand seed, sequence offset).
    // Requires SM70+ (or HIP).
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int gid = threadIdx.x >> 5;      // warp index within the block
    int lane = threadIdx.x & 0x1f;   // lane index within the warp
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    int idx = blockIdx.x * blockDim.x + threadIdx.x;  // unique curand subsequence id
    float2* vals_cast = reinterpret_cast<float2*>(vals);  // vectorized view: 4 halves per load
    // Register cache of this thread's slice of the token, split into the
    // low/high __half2 halves of each float2 load.
    // NOTE(review): capacity 128 assumes token_size <= 128 * blockDim.x
    // float2 elements — confirm with callers.
    __half2 data_low[128];
    __half2 data_high[128];
    int bid = blockIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;                 // number of float2 elements cached so far
    int offset = bid * token_size;     // start of this token (float2 units)
    int group_index = bid * token_size + tid;
    int total_count = token_size * token_num;
    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;
        // Pass 1: cache the token in registers and track max |x|.
        while (tid < token_size) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];
            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);
            if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
            if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
            if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
            if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);
            tid += blockDim.x;
            reg_count++;
        }
        // Intra-warp max reduction (butterfly shuffle).
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }
        // Cross-warp reduction through shared memory.
        __shared__ float partialMax[WARP_SIZE];
        if (lane == 0) partialMax[gid] = max;
        b.sync();
        if (lane < warp_num) max = partialMax[lane];
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }
        max = g.shfl(max, 0);  // broadcast the block max to every lane
        // Symmetric scale: 2^num_bits levels span [-max, max]; epsilon avoids /0.
        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));
        // Pass 2: truncate toward zero, stochastically round, dequantize, store.
        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);
                float2 q_data_int[2];
                // (int) truncates toward zero; the stochastic step below compensates.
                q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
                q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
                q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
                q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));
                // Stochastic rounding
                float4 rand = curand_uniform4(&state);
                // q_error: magnitude lost by truncation, in quantized units (in [0, 1)).
                float q_error[4];
                q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
                q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;
                // With probability q_error, step one level away from zero, but
                // only when strictly inside the representable range.
                // Note: rand.w pairs with q_error[2] and rand.z with q_error[3];
                // the swap is harmless since the components are i.i.d. uniforms.
                q_data_int[0].x =
                    (rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
                        ? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
                        : q_data_int[0].x;
                q_data_int[0].y =
                    (rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
                        ? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
                        : q_data_int[0].y;
                q_data_int[1].x =
                    (rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
                        ? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
                        : q_data_int[1].x;
                q_data_int[1].y =
                    (rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
                        ? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
                        : q_data_int[1].y;
                // Dequantize and pack back to 4 halves per store.
                data_f[0].x = q_data_int[0].x / q_scale_val;
                data_f[0].y = q_data_int[0].y / q_scale_val;
                data_f[1].x = q_data_int[1].x / q_scale_val;
                data_f[1].y = q_data_int[1].y / q_scale_val;
                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);
                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}
__global__ void sr_quantize_kernel(float* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
    // In-place symmetric quantize -> dequantize with stochastic rounding.
    // One block handles one token; `token_size` is counted in float4 units.
    // `seed` is (curand seed, sequence offset).
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int gid = threadIdx.x >> 5;      // warp index within the block
    int lane = threadIdx.x & 0x1f;   // lane index within the warp
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    int id = threadIdx.x;
    int idx = blockIdx.x * blockDim.x + id;  // unique curand subsequence id
    float4* vals_cast = reinterpret_cast<float4*>(vals);  // vectorized view
    // Register cache of this thread's slice of the token.
    // NOTE(review): capacity 128 assumes token_size <= 128 * blockDim.x
    // float4 elements — confirm with callers.
    float4 data[128];
    int bid = blockIdx.x;
    int tid = threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);
    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;  // number of float4 elements cached so far
    int total_count = token_size * token_num;
    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;
        // Pass 1: cache the token in registers and track max |x|.
        while (tid < token_size) {
            data[reg_count] = vals_cast[group_index];
            if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
            if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
            if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
            if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);
            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }
        // Intra-warp max reduction (butterfly shuffle).
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }
        // Cross-warp reduction through shared memory.
        __shared__ float partialMax[WARP_SIZE];
        if (lane == 0) partialMax[gid] = max;
        b.sync();
        if (lane < warp_num) max = partialMax[lane];
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }
        max = g.shfl(max, 0);  // broadcast the block max to every lane
        // Symmetric scale: 2^num_bits levels span [-max, max]; epsilon avoids /0.
        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));
        int offset = (bid)*token_size;
        // Pass 2: truncate toward zero, stochastically round, dequantize, store.
        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];
                float4 q_data_int;
                // (int) truncates toward zero; the stochastic step below compensates.
                q_data_int.x = (float)((int)(q_data.x * q_scale_val));
                q_data_int.y = (float)((int)(q_data.y * q_scale_val));
                q_data_int.w = (float)((int)(q_data.w * q_scale_val));
                q_data_int.z = (float)((int)(q_data.z * q_scale_val));
                // Stochastic rounding
                float4 rand = curand_uniform4(&state);
                // q_error: magnitude lost by truncation, in quantized units (in [0, 1)).
                float q_error[4];
                q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
                q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;
                // With probability q_error, step one level away from zero, but
                // only when strictly inside the representable range.
                q_data_int.x =
                    (rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
                        ? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
                        : q_data_int.x;
                q_data_int.y =
                    (rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
                        ? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
                        : q_data_int.y;
                q_data_int.w =
                    (rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
                        ? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
                        : q_data_int.w;
                q_data_int.z =
                    (rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
                        ? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
                        : q_data_int.z;
                // Dequantize in place and write back.
                q_data_int.x /= q_scale_val;
                q_data_int.y /= q_scale_val;
                q_data_int.w /= q_scale_val;
                q_data_int.z /= q_scale_val;
                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}
// Host-side launcher for the stochastic-rounding symmetric quantizer.
// Grid: one block per group; block: 1024 threads. The kernel consumes
// vectorized elements (float4, or four packed halves), so the per-group
// element count is divided by 4. Advances the global curand offset so each
// launch draws a fresh random stream.
template <typename T>
void launch_sr_quantize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               cudaStream_t stream)
{
    const dim3 block_dim(1024);
    const dim3 grid_dim(group_num);
    const uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
    const int vec_token_size = (total_count / group_num) / 4;
    sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
        vals, vec_token_size, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel(float* vals,
                                        int total_count,
                                        int group_num,
                                        int num_bits,
                                        cudaStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
                                        int total_count,
                                        int group_num,
                                        int num_bits,
                                        cudaStream_t stream);
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    // In-place asymmetric (min/max) quantize -> dequantize, round-to-nearest.
    // One block handles one group; `group_size` is counted in float2 units
    // (4 halves per element). Requires SM70+ (or HIP).
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int gid = threadIdx.x >> 5;      // warp index within the block
    int lane = threadIdx.x & 0x1f;   // lane index within the warp
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    int id = threadIdx.x;
    float2* vals_cast = reinterpret_cast<float2*>(vals);  // vectorized view: 4 halves per load
    float2 data[MAX_REG];  // register cache of this thread's slice of the group
    int group_id = blockIdx.x;
    {
        int group_index = id;
        int reg_count = 0;  // number of float2 elements cached so far
        int offset = group_id * group_size;
        float max = -10000.0;
        float min = 10000.0;
        // Pass 1: cache the group in registers and find its min and max.
        // NOTE(review): elements beyond MAX_REG * blockDim.x are neither
        // scanned nor rewritten — confirm callers bound group sizes accordingly.
        while (group_index < group_size && reg_count < MAX_REG) {
            data[reg_count] = vals_cast[offset + group_index];
            // Each float2 carries four packed halves; inspect all of them.
            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
            if (((float)data_h[0]) > max) max = (float)data_h[0];
            if (((float)data_h[1]) > max) max = (float)data_h[1];
            if (((float)data_h[2]) > max) max = (float)data_h[2];
            if (((float)data_h[3]) > max) max = (float)data_h[3];
            if (((float)data_h[0]) < min) min = (float)data_h[0];
            if (((float)data_h[1]) < min) min = (float)data_h[1];
            if (((float)data_h[2]) < min) min = (float)data_h[2];
            if (((float)data_h[3]) < min) min = (float)data_h[3];
            group_index += blockDim.x;
            reg_count++;
        }
        // Intra-warp max reduction (butterfly shuffle).
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }
        // Intra-warp min reduction.
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }
        // Cross-warp reductions through shared memory.
        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];
        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;
        b.sync();
        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }
        max = g.shfl(max, 0);  // broadcast block max
        min = g.shfl(min, 0);  // broadcast block min
        // Step size: 2^num_bits bins across [min, max]; epsilon avoids /0.
        float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float q_scale_inv = 1 / q_scale;
        // Pass 2: quantize (round-to-nearest bin index) and dequantize in place.
        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + id;
            if (group_index < group_size) {
                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
                float2 q_data[2];
                q_data[0] = __half22float2(data_h[0]);
                q_data[1] = __half22float2(data_h[1]);
                float2 q_data_int[2];
                q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
                q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
                q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
                q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);
                // Dequantize: bin index * step + min.
                q_data_int[0].x = q_data_int[0].x * q_scale + min;
                q_data_int[0].y = q_data_int[0].y * q_scale + min;
                q_data_int[1].x = q_data_int[1].x * q_scale + min;
                q_data_int[1].y = q_data_int[1].y * q_scale + min;
                data_h[0] = __float22half2_rn(q_data_int[0]);
                data_h[1] = __float22half2_rn(q_data_int[1]);
                vals_cast[offset + group_index] = data[i];
            }
        }
    }
#endif
}
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
    // In-place asymmetric (min/max) quantize -> dequantize, round-to-nearest.
    // One block handles one group; `group_size` is counted in float4 units.
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int gid = threadIdx.x >> 5;      // warp index within the block
    int lane = threadIdx.x & 0x1f;   // lane index within the warp
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    int id = threadIdx.x;
    float4* vals_cast = reinterpret_cast<float4*>(vals);  // vectorized view
    float4 data[MAX_REG];  // register cache of this thread's slice of the group
    int bid = blockIdx.x;
    int group_index = bid * group_size + id;
    int reg_count = 0;  // number of float4 elements cached so far
    float max = -10000.0;
    float min = 10000.0;
    // Pass 1: cache the group in registers and find its min and max.
    // `id` is advanced as the loop cursor and restored afterwards.
    // NOTE(review): elements beyond MAX_REG * blockDim.x are neither scanned
    // nor rewritten — confirm callers bound group sizes accordingly.
    while (id < group_size && reg_count < MAX_REG) {
        float4 data_reg = vals_cast[group_index];
        data[reg_count] = data_reg;
        if (data_reg.x > max) max = data_reg.x;
        if (data_reg.y > max) max = data_reg.y;
        if (data_reg.w > max) max = data_reg.w;
        if (data_reg.z > max) max = data_reg.z;
        if (data_reg.x < min) min = data_reg.x;
        if (data_reg.y < min) min = data_reg.y;
        if (data_reg.w < min) min = data_reg.w;
        if (data_reg.z < min) min = data_reg.z;
        group_index += blockDim.x;
        id += blockDim.x;
        reg_count++;
    }
    id = threadIdx.x;  // restore thread id after using it as the loop cursor
    // Intra-warp max reduction (butterfly shuffle).
#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }
    // Intra-warp min reduction.
#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(min, i);
        if (min > temp) min = temp;
    }
    // Cross-warp reductions through shared memory.
    __shared__ float partialMax[WARP_SIZE];
    __shared__ float partialMin[WARP_SIZE];
    if (lane == 0) partialMax[gid] = max;
    if (lane == 0) partialMin[gid] = min;
    b.sync();
    if (lane < warp_num) max = partialMax[lane];
    if (lane < warp_num) min = partialMin[lane];
#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }
#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(min, i);
        if (min > temp) min = temp;
    }
    max = g.shfl(max, 0);  // broadcast block max
    min = g.shfl(min, 0);  // broadcast block min
    // Step size: 2^num_bits bins across [min, max]; epsilon avoids /0.
    float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
    float q_scale_inv = 1 / q_scale;
    // Pass 2: quantize (round-to-nearest bin index) and dequantize in place.
    for (int i = 0; i < reg_count; i++) {
        group_index = i * blockDim.x + id;
        if (group_index < group_size) {
            float4 q_data;
            q_data = data[i];
            float4 q_data_int;
            q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
            q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
            q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
            q_data_int.z = roundf((q_data.z - min) * q_scale_inv);
            // Dequantize: bin index * step + min.
            q_data.x = q_data_int.x * q_scale + min;
            q_data.y = q_data_int.y * q_scale + min;
            q_data.w = q_data_int.w * q_scale + min;
            q_data.z = q_data_int.z * q_scale + min;
            vals_cast[group_index + bid * group_size] = q_data;
        }
    }
}
// Host-side launcher for the asymmetric (min/max) quantizer.
// Grid: one block per group; block: 1024 threads. The kernel operates on
// vectorized elements (float4, or four packed halves), so the per-group
// element count is divided by 4.
template <typename T>
void launch_quantize_kernel_asym(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream)
{
    const dim3 grid_dim(group_num);
    const dim3 block_dim(1024);
    const int vec_group_size = (total_count / group_num) / 4;
    quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(vals, vec_group_size, num_bits);
}
template void launch_quantize_kernel_asym(float* vals,
                                          int total_count,
                                          int group_num,
                                          int num_bits,
                                          cudaStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
                                          int total_count,
                                          int group_num,
                                          int num_bits,
                                          cudaStream_t stream);
__global__ void sr_quantize_kernel_asym(__half* vals,
                                        int token_size,
                                        int token_num,
                                        int num_bits,
                                        std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    // In-place asymmetric (min/max) quantize -> dequantize with stochastic
    // rounding. One block handles one token; `token_size` is counted in
    // float2 units (4 halves per element). Requires SM70+ (or HIP).
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int gid = threadIdx.x >> 5;      // warp index within the block
    int lane = threadIdx.x & 0x1f;   // lane index within the warp
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    int idx = blockIdx.x * blockDim.x + threadIdx.x;  // unique curand subsequence id
    float2* vals_cast = reinterpret_cast<float2*>(vals);  // vectorized view: 4 halves per load
    // Register cache of this thread's slice (low/high __half2 of each float2).
    // NOTE(review): capacity 128 assumes token_size <= 128 * blockDim.x — confirm.
    __half2 data_low[128];
    __half2 data_high[128];
    int bid = blockIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;                 // number of float2 elements cached so far
    int offset = bid * token_size;     // start of this token (float2 units)
    int group_index = bid * token_size + tid;
    int total_count = token_size * token_num;
    if (group_index < total_count) {
        float min = 10000.0;
        float max = -10000.0;
        // Pass 1: cache the token in registers and find its min and max.
        while (tid < token_size) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];
            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);
            if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
            if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
            if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
            if (((float)data_f[1].y) > max) max = (float)data_f[1].y;
            if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
            if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
            if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
            if (((float)data_f[1].y) < min) min = (float)data_f[1].y;
            tid += blockDim.x;
            reg_count++;
        }
        // Intra-warp max reduction (butterfly shuffle).
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }
        // Intra-warp min reduction.
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }
        // Cross-warp reductions through shared memory.
        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];
        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;
        b.sync();
        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }
        max = g.shfl(max, 0);  // broadcast block max
        min = g.shfl(min, 0);  // broadcast block min
        // Step size: 2^num_bits bins across [min, max]; epsilon avoids /0.
        float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float q_scale_val_inv = 1 / q_scale_val;
        float high_q = (float)((1 << num_bits) - 1);  // largest unsigned bin index
        // Pass 2: truncate to a bin, stochastically round up, dequantize, store.
        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);
                float2 q_data_int[2];
                // (unsigned int) truncates downward; values are >= min so the
                // shifted argument is non-negative.
                q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
                q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
                q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
                q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));
                // Stochastic rounding
                float4 rand = curand_uniform4(&state);
                // q_error: fraction of a bin lost by truncation (in [0, 1)).
                float q_error[4];
                q_error[0] =
                    abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
                q_error[1] =
                    abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
                q_error[2] =
                    abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
                q_error[3] =
                    abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;
                // With probability q_error, round up one bin, capped at high_q.
                // Note: rand.w pairs with q_error[2] and rand.z with q_error[3];
                // the swap is harmless since the components are i.i.d. uniforms.
                q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
                                      ? (q_data_int[0].x + 1)
                                      : q_data_int[0].x;
                q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
                                      ? (q_data_int[0].y + 1)
                                      : q_data_int[0].y;
                q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
                                      ? (q_data_int[1].x + 1)
                                      : q_data_int[1].x;
                q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
                                      ? (q_data_int[1].y + 1)
                                      : q_data_int[1].y;
                // Dequantize: bin index * step + min, then pack back to halves.
                data_f[0].x = q_data_int[0].x * q_scale_val + min;
                data_f[0].y = q_data_int[0].y * q_scale_val + min;
                data_f[1].x = q_data_int[1].x * q_scale_val + min;
                data_f[1].y = q_data_int[1].y * q_scale_val + min;
                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);
                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}
__global__ void sr_quantize_kernel_asym(float* vals,
                                        int token_size,
                                        int token_num,
                                        int num_bits,
                                        std::pair<uint64_t, uint64_t> seed)
{
    // In-place asymmetric (min/max) quantize -> dequantize with stochastic
    // rounding. One block handles one token; `token_size` is counted in
    // float4 units.
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int gid = threadIdx.x >> 5;      // warp index within the block
    int lane = threadIdx.x & 0x1f;   // lane index within the warp
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    int id = threadIdx.x;
    int idx = blockIdx.x * blockDim.x + id;  // unique curand subsequence id
    float4* vals_cast = reinterpret_cast<float4*>(vals);  // vectorized view
    // Register cache of this thread's slice of the token.
    // NOTE(review): capacity 128 assumes token_size <= 128 * blockDim.x — confirm.
    float4 data[128];
    int bid = blockIdx.x;
    int tid = threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);
    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;  // number of float4 elements cached so far
    int total_count = token_size * token_num;
    if (group_index < total_count) {
        float min = 10000.0;
        float max = -10000.0;
        // Pass 1: cache the token in registers and find its min and max.
        while (tid < token_size) {
            float4 data_reg = vals_cast[group_index];
            data[reg_count] = data_reg;
            if (data_reg.x > max) max = data_reg.x;
            if (data_reg.y > max) max = data_reg.y;
            if (data_reg.w > max) max = data_reg.w;
            if (data_reg.z > max) max = data_reg.z;
            if (data_reg.x < min) min = data_reg.x;
            if (data_reg.y < min) min = data_reg.y;
            if (data_reg.w < min) min = data_reg.w;
            if (data_reg.z < min) min = data_reg.z;
            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }
        // Intra-warp max reduction (butterfly shuffle).
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }
        // Intra-warp min reduction.
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }
        // Cross-warp reductions through shared memory.
        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];
        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;
        b.sync();
        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }
#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }
        max = g.shfl(max, 0);  // broadcast block max
        min = g.shfl(min, 0);  // broadcast block min
        // Step size: 2^num_bits bins across [min, max]; epsilon avoids /0.
        float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float high_q = (float)((1 << num_bits) - 1);  // largest unsigned bin index
        int offset = (bid)*token_size;
        // Pass 2: truncate to a bin, stochastically round up, dequantize, store.
        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];
                float4 q_data_int;
                // (int) truncation; argument is >= 0 because values are >= min.
                q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
                q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
                q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
                q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));
                // Stochastic rounding
                float4 rand = curand_uniform4(&state);
                // q_error: fraction of a bin lost by truncation (in [0, 1)).
                float q_error[4];
                q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
                q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
                q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
                q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;
                // With probability q_error, round up one bin, capped at high_q.
                q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
                                                                              : q_data_int.x;
                q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
                                                                              : q_data_int.y;
                q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
                                                                              : q_data_int.w;
                q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
                                                                              : q_data_int.z;
                // Dequantize in place: bin index * step + min, then write back.
                q_data_int.x = q_data_int.x * q_scale_val + min;
                q_data_int.y = q_data_int.y * q_scale_val + min;
                q_data_int.w = q_data_int.w * q_scale_val + min;
                q_data_int.z = q_data_int.z * q_scale_val + min;
                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}
// Host-side launcher for the asymmetric stochastic-rounding quantizer.
// Grid: one block per group; block: 1024 threads. The kernel consumes
// vectorized elements (float4, or four packed halves), so the per-group
// element count is divided by 4. Advances the global curand offset so each
// launch draws a fresh random stream.
// Bug fix: this previously launched the *symmetric* sr_quantize_kernel,
// so the asymmetric stochastic-rounding path was never executed.
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream)
{
    dim3 block_dim(1024);
    dim3 grid_dim(group_num);
    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
    sr_quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
        vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel_asym(float* vals,
                                             int total_count,
                                             int group_num,
                                             int num_bits,
                                             cudaStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
                                             int total_count,
                                             int group_num,
                                             int num_bits,
                                             cudaStream_t stream);
|