#include "custom_cuda_layers.h"

#define rows_trans 16
#define cols_trans 16

// Tiled matrix transpose: each block stages one 16x16 tile in shared memory,
// then writes it back with the row/column roles swapped. Each tile row is
// padded by one element so the column-wise reads after the transpose fall in
// distinct shared-memory banks.
template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
    __shared__ T data_block[rows_trans * (cols_trans + 1)];

    int r = threadIdx.x / cols_trans;
    int c = threadIdx.x % cols_trans;

    int m = row_width / cols_trans;  // tiles per input row

    int i = blockIdx.x / m * rows_trans + r;
    int j = blockIdx.x % m * cols_trans + c;

    // Rows covered per loop step when the block has fewer threads than tile elements.
    int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);

    for (int k = 0; k < rows_trans; k += row_stride)
        data_block[(k + r) * (cols_trans + 1) + c] = inp[(i + k) * row_width + j];

    __syncthreads();

    i = blockIdx.x % m * rows_trans + r;
    j = blockIdx.x / m * cols_trans + c;

    for (int k = 0; k < rows_trans; k += row_stride)
        out[(i + k) * col_width + j] = data_block[c * (cols_trans + 1) + r + k];
}
template <>
void Transpose<__half>(const __half* inp_mat,
                       __half* out_mat,
                       int rows,
                       int cols,
                       cudaStream_t stream)
{
    int threads = THREADS;
    // One THREADS-thread block per 16x16 tile; the kernel takes the input
    // matrix's (row_width, col_width), i.e. (cols, rows).
    Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
        inp_mat, out_mat, cols, rows);
}

template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream)
{
    int threads = THREADS;
    Transpose_Kernel<float><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
        inp_mat, out_mat, cols, rows);
}
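// Illustrative only: a minimal host-side sketch of calling Transpose (the
// shapes here are assumptions, not taken from this file). rows and cols
// should be multiples of the 16x16 tile, since each block moves exactly one
// tile.
//
//   float *src, *dst;
//   cudaMalloc(&src, 1024 * 768 * sizeof(float));
//   cudaMalloc(&dst, 1024 * 768 * sizeof(float));
//   Transpose<float>(src, dst, 1024, 768, stream);  // dst holds the 768x1024 transpose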
// 0213 transform: permute the 4-D view [B, S, A, N] of a [B, S, H] activation
// to [B, A, S, N], where A = heads and N = H / A.
template <typename T>
__global__ void transform_0213(T* output,
                               const T* vals,
                               int hidden_dim,
                               int seq_length,
                               int heads,
                               int head_ext);

template <>
__global__ void transform_0213<float>(float* output,
                                      const float* vals,
                                      int hidden_dim,
                                      int seq_length,
                                      int heads,
                                      int head_ext)
{
    // hidden_dim is given in float4 units (elements / 4); all strides below
    // index float4 loads/stores.
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}
template <>
__global__ void transform_0213<__half>(__half* output,
                                       const __half* vals,
                                       int hidden_dim,
                                       int seq_length,
                                       int heads,
                                       int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    float4 vals_arr[1];

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];

#endif
}
template <>
void launch_transform_0213<float>(float* output,
                                  const float* vals,
                                  int batch_size,
                                  int seq_length,
                                  int hidden_dim,
                                  int heads,
                                  cudaStream_t stream)
{
    hidden_dim >>= 2;  // work in float4 units (4 floats per load/store)
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, (seq_length * head_ext));
    transform_0213<float>
        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
}

template <>
void launch_transform_0213<__half>(__half* output,
                                   const __half* vals,
                                   int batch_size,
                                   int seq_length,
                                   int hidden_dim,
                                   int heads,
                                   cudaStream_t stream)
{
    hidden_dim >>= 3;  // work in float4 units (8 halves per load/store)
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, (seq_length * head_ext));
    transform_0213<__half>
        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
}
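// Illustrative only: a sketch of how launch_transform_0213 might be invoked
// for a 12-head model (shapes are assumptions, not taken from this file).
// hidden_dim is passed in elements; the launcher divides by 4 (float) or
// 8 (__half) to index in float4 units, so hidden_dim / heads should be a
// multiple of 4 or 8 respectively.
//
//   // q: [batch, 128, 768] -> q_t: [batch, 12, 128, 64]
//   launch_transform_0213<float>(q_t, q, batch, 128, 768, 12, stream);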
// Bias add fused with the 0213 transform: the fused QKV activation
// [B, S, C*H] is bias-added and scattered to C x [B, A, S, N].
template <typename T>
__global__ void bias_add_transform_0213(T* output,
                                        const T* vals,
                                        const T* bias,
                                        int hidden_dim,
                                        int seq_length,
                                        int heads,
                                        int head_ext);

template <>
__global__ void bias_add_transform_0213<float>(float* output,
                                               const float* vals,
                                               const float* bias,
                                               int hidden_dim,
                                               int seq_length,
                                               int heads,
                                               int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
    int cnt = blockIdx.z / head_ext;                                      // Hidden count
    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
                             d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
    float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];

    float4 outputs;
    outputs.x = inputs.x + biases.x;
    outputs.y = inputs.y + biases.y;
    outputs.z = inputs.z + biases.z;
    outputs.w = inputs.w + biases.w;

    output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
               d2 * d2_out_stride + d3] = outputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10

template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
                                                const __half* vals,
                                                const __half* bias,
                                                int hidden_dim,
                                                int seq_length,
                                                int heads,
                                                int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
    int cnt = blockIdx.z / head_ext;                                      // Hidden count
    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    float4 vals_arr;
    float4 bias_arr;
    float4 output_arr;
    __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(&output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));  // batch offset
    vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));  // sequence offset
    vals_vec += (cnt * d1_stride);                          // q/k/v offset
    vals_vec += (d2 * d2_stride);                           // head offset

    bias_vec += (cnt * d1_stride);
    bias_vec += (d2 * d2_stride);

    output_vec += (cnt * d0_stride * gridDim.x);  // q/k/v offset
    output_vec += (d1 * d2_stride);               // sequence offset
    output_vec += (d0 * d0_stride);               // batch offset
    output_vec += (d2 * d2_out_stride);           // head offset

    bias_arr = bias_vec[d3];
    vals_arr = vals_vec[d3];

#if defined(__ACC_HALF__)
    output_half[0] = vals_half[0] + bias_half[0];
    output_half[1] = vals_half[1] + bias_half[1];
    output_half[2] = vals_half[2] + bias_half[2];
    output_half[3] = vals_half[3] + bias_half[3];
#else
    // Without native half2 arithmetic, accumulate in float and round back to half.
    float2 bias_arr_f[4];
    float2 vals_arr_f[4];
#pragma unroll
    for (int l = 0; l < 4; l++) {
        bias_arr_f[l] = __half22float2(bias_half[l]);
        vals_arr_f[l] = __half22float2(vals_half[l]);
        vals_arr_f[l].x += bias_arr_f[l].x;
        vals_arr_f[l].y += bias_arr_f[l].y;
        output_half[l] = __float22half2_rn(vals_arr_f[l]);
    }
#endif
    output_vec[d3] = output_arr;

#endif
}
// Shared-memory variant of the fused bias add + 0213 transform for __half.
// Each block covers two sequence positions and all q/k/v matrices
// (blockDim.z == trans_count), staging the bias-added rows in shared memory
// before writing them out in [B, A, S, N] order.
__global__ void bias_add_transform_0213_v2(__half* output,
                                           const __half* vals,
                                           const __half* bias,
                                           int hidden_dim,
                                           int seq_length,
                                           int heads)
{
#if __CUDA_ARCH__ >= 700
    __shared__ float4 in_data[3072];  // 3072 float4 = 48 KB staging buffer

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;
    int iteration_stride = d1_stride * blockDim.z;  // Hidden * 3 / 8
    int batch_stride = d0_stride * blockDim.z;      // Hidden * S * 3 / 8

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;    // Batch
    int d1 = blockIdx.y;    // Sequence ID (0-127)
    int cnt = threadIdx.z;  // Hidden count
    int d2 = threadIdx.y;   // Head (0-11)
    int d3 = threadIdx.x;   // Values (groups of 4)

    float4 vals_arr[1];
    float4 bias_arr[1];
    float4 output_arr[1];
    __half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
    int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
    bias_arr[0] = bias_vec[iter_index];

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        vals_arr[0] = vals_vec[input_offset + iter_id];

        output_half[0] = vals_half[0] + bias_half[0];
        output_half[1] = vals_half[1] + bias_half[1];
        output_half[2] = vals_half[2] + bias_half[2];
        output_half[3] = vals_half[3] + bias_half[3];

        in_data[iter_id] = output_arr[0];
    }
    __syncthreads();

    iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_out_stride * gridDim.x);
    int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);

    int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_row = (iter * iteration_stride) + head_count;
        int iter_offset =
            (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
        output_vec[out_index + iter_offset] =
            in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
    }
#endif
}
// [B, S, C*H] -> C x [B, A, S, N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
                                           const float* vals,
                                           const float* bias,
                                           int batch_size,
                                           int seq_length,
                                           int hidden_dim,
                                           int heads,
                                           cudaStream_t stream,
                                           int trans_count)
{
    hidden_dim >>= 2;  // work in float4 units
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
    bias_add_transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
        output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}

template <>
void launch_bias_add_transform_0213<__half>(__half* output,
                                            const __half* vals,
                                            const __half* bias,
                                            int batch_size,
                                            int seq_length,
                                            int hidden_dim,
                                            int heads,
                                            cudaStream_t stream,
                                            int trans_count)
{
    hidden_dim >>= 3;  // work in float4 units
    if (hidden_dim > 128 || hidden_dim < 16) {
        // Fall back to the generic kernel when the shared-memory variant's
        // block shape would be too large or too small.
        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
        dim3 block_dim(hidden_dim / heads, (heads / head_ext));
        dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
        bias_add_transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads, head_ext);
    } else {
        dim3 block_dim(hidden_dim / heads, heads, trans_count);
        dim3 grid_dim(batch_size, seq_length / 2);  // two sequence positions per block
        bias_add_transform_0213_v2<<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads);
    }
}
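// Worked example (illustrative, shapes assumed): with batch_size = 8,
// seq_length = 128, hidden_dim = 768, heads = 12 (so N = 64) and
// trans_count = 3, the fused QKV activation is laid out as [8, 128, 3*768].
// In element units, entry (b, s, m, h, n) -- batch b, sequence position s,
// matrix m in {q, k, v}, head h, per-head offset n -- is read from
//
//   vals[b * (128 * 3 * 768) + s * (3 * 768) + m * 768 + h * 64 + n]
//
// and written, bias-added, to
//
//   output[m * (8 * 128 * 768) + b * (128 * 768) + h * (64 * 128) + s * 64 + n]
//
// i.e. three contiguous [B, A, S, N] tensors, one per q/k/v matrix.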
// Inverse transform: C x [B, A, S, N] -> [B, S, C*H].
template <typename T>
__global__ void transform4d_0213(T* out,
                                 const T* in,
                                 int heads,
                                 int seq_length,
                                 int hidden_dim,
                                 int head_ext);

template <>
__global__ void transform4d_0213<float>(float* out,
                                        const float* in,
                                        int heads,
                                        int seq_length,
                                        int hidden_dim,
                                        int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = d0_stride / heads;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = hidden_dim;

    int d0 = blockIdx.x;                                        // Batch
    int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1);  // Head
    int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
    int cnt = blockIdx.z;
    int d3 = threadIdx.x;  // Values (groups of 4)

    if (d2 < seq_length) {
        const float4* in_vec = reinterpret_cast<const float4*>(in);
        float4* out_vec = reinterpret_cast<float4*>(out);

        float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
                                 d2 * d2_stride + d3];
        out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
                d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
    }
}

template <>
__global__ void transform4d_0213<__half>(__half* out,
                                         const __half* in,
                                         int heads,
                                         int seq_length,
                                         int hidden_dim,
                                         int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * (seq_length / head_ext);
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head
    int d2 = blockIdx.z / head_ext;                                       // Sequence
    int cnt = blockIdx.y;                                                 // Hidden count
    int d3 = threadIdx.x;                                                 // Values (groups of 8)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    in_vec += (cnt * d0_stride * gridDim.x);  // q/k/v offset
    in_vec += (d0 * d0_stride);               // batch offset
    in_vec += (d2 * d2_stride);               // sequence offset
    in_vec += (d1 * d2_stride * seq_length);  // head offset

    out_vec += (cnt * d1_stride);
    out_vec += (d1 * d2_stride);
    out_vec += (d0 * d0_stride * gridDim.y);
    out_vec += (d2 * d1_stride * gridDim.y);

    out_vec[d3] = in_vec[d3];

#endif
}
__global__ void transform4d_0213_v2(__half* out,
                                    const __half* in,
                                    int heads,
                                    int seq_length,
                                    int hidden_dim)
{
#if __CUDA_ARCH__ >= 700
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0 = blockIdx.x;    // Batch
    int d1 = threadIdx.y;   // Head
    int d2 = blockIdx.y;    // Sequence
    int cnt = threadIdx.z;  // Hidden count
    int d3 = threadIdx.x;   // Values (groups of 8)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
    int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
    int iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_stride * gridDim.x);

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_row = iter * iteration_stride + head_count;
        int iter_offset = (iter_row % blockDim.y) * d2_stride;

        in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
            in_vec[input_offset + iter_offset * seq_length +
                   (iter_row / blockDim.y) * matrix_stride];
    }
    __syncthreads();

    iteration_stride = d1_stride * blockDim.z;
    int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
    int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        out_vec[output_offset + iter_id] = in_data[iter_id];
    }
#endif
}
// 3 x [B, A, S, N] -> [B, S, C*H]
template <>
void launch_transform4d_0213<float>(float* out,
                                    const float* in,
                                    int batch_size,
                                    int heads,
                                    int seq_length,
                                    int hidden_dim,
                                    cudaStream_t stream,
                                    int trans_count)
{
    hidden_dim >>= 2;  // work in float4 units
    // 8 sequence positions per block; head_ext is fixed at 1 for float.
    dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
    dim3 block_dims(hidden_dim / heads, 8);
    transform4d_0213<float>
        <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim, 1);
}

template <>
void launch_transform4d_0213<__half>(__half* out,
                                     const __half* in,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     int trans_count)
{
    hidden_dim >>= 3;  // work in float4 units
    if (hidden_dim > 128 || hidden_dim < 16) {
        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
        dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
        dim3 block_dims(hidden_dim / heads, (heads / head_ext));
        transform4d_0213<__half><<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim, head_ext);
    } else {
        dim3 grid_dims(batch_size, seq_length / 2);
        dim3 block_dims(hidden_dim / heads, heads, trans_count);
        transform4d_0213_v2<<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim);
    }
}
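// Illustrative only: how the two launchers pair up in an attention block
// (all shapes are assumptions). launch_bias_add_transform_0213 scatters the
// fused QKV GEMM output [B, S, 3*H] into 3 x [B, A, S, N] for batched
// attention, and launch_transform4d_0213 is its layout inverse, gathering the
// attention context [B, A, S, N] back to [B, S, H] (trans_count = 1) so the
// output projection GEMM sees a contiguous hidden dimension.
//
//   launch_bias_add_transform_0213<__half>(qkv_t, qkv, qkv_bias,
//                                          8, 128, 768, 12, stream, 3);
//   // ... batched softmax(Q K^T) V using the per-head layout ...
//   launch_transform4d_0213<__half>(ctx_out, ctx, 8, 12, 128, 768, stream, 1);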