// general_kernels.cu

#include "general_kernels.h"

namespace cg = cooperative_groups;

// Sums each column of a [rows x width] matrix. Every block owns TILE_DIM
// columns: threads accumulate partial sums down the rows, a shared-memory
// transpose regroups the partials per column, and a shuffle reduction
// produces one value per column.
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
                                  T* __restrict__ out,
                                  int rows,
                                  int width)
{
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;

    int y_stride = width * TILE_DIM;

    float localSum = 0;

    // Loop across matrix height
    if (idx < width) {
        int offset = threadIdx.y * width + idx;
        for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
            localSum += (float)inp[offset];
            offset += y_stride;
        }
    }

    tile[threadIdx.x][threadIdx.y] = localSum;

    __syncthreads();

    // Sum the shared buffer.
    float sum = tile[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        if (pos < width) out[pos] = sum;
    }
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
                                              float* out,
                                              int rows,
                                              int cols,
                                              cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}

template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
                                               __half* out,
                                               int rows,
                                               int cols,
                                               cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
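
// Usage sketch: a hypothetical helper showing how the launcher above might be
// driven to reduce an output-gradient matrix into a bias gradient. The buffer
// names (d_out_grad, d_bias_grad) are illustrative assumptions, not part of
// the library's API.
static inline void example_bias_grad_reduce(const float* d_out_grad,
                                            float* d_bias_grad,
                                            int rows,
                                            int cols,
                                            cudaStream_t stream)
{
    // Sums the [rows x cols] matrix over its rows, writing one value per column.
    launch_fuse_transpose_bias_kernel<float>(d_out_grad, d_bias_grad, rows, cols, stream);
}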
// Element-wise sum of two float tensors, vectorized as float4 loads/stores.
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    float4* out_4 = reinterpret_cast<float4*>(out);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 val;
        float4 inp1_reg = inp1_4[j];
        float4 inp2_reg = inp2_4[j];

        val.x = inp1_reg.x + inp2_reg.x;
        val.y = inp1_reg.y + inp2_reg.y;
        val.z = inp1_reg.z + inp2_reg.z;
        val.w = inp1_reg.w + inp2_reg.w;

        out_4[j] = val;
    }
}

// Element-wise sum of two half tensors: each float2 load carries four halfs,
// which are widened to float for the add and rounded back to half on store.
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
    float2 inp1_4;
    float2 inp2_4;

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        inp1_4 = inp1_arr[j];
        inp2_4 = inp2_arr[j];

        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

        float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
        float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

        inp1_h_f_0.x += inp2_h_f_0.x;
        inp1_h_f_0.y += inp2_h_f_0.y;
        inp1_h_f_1.x += inp2_h_f_1.x;
        inp1_h_f_1.y += inp2_h_f_1.y;

        float2 val_f;
        __half2* val_h = reinterpret_cast<__half2*>(&val_f);

        val_h[0] = __float22half2_rn(inp1_h_f_0);
        val_h[1] = __float22half2_rn(inp1_h_f_1);

        float2* out_4 = reinterpret_cast<float2*>(out);
        out_4[j] = val_f;
    }
}
template <>
void launch_fused_add2<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              int batch_size,
                              int seq_length,
                              int hidden_dim,
                              cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}

template <>
void launch_fused_add2<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               int batch_size,
                               int seq_length,
                               int hidden_dim,
                               cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
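
// Usage sketch: a hypothetical fp16 residual add over a [batch, seq, hidden]
// activation. The tensor names below are illustrative assumptions; the one
// real requirement carried over from the kernels is that the total element
// count be divisible by 4, since every loop iteration consumes four values.
static inline void example_residual_add_fp16(__half* d_out,
                                             const __half* d_hidden,
                                             const __half* d_residual,
                                             int batch_size,
                                             int seq_length,
                                             int hidden_dim,
                                             cudaStream_t stream)
{
    launch_fused_add2<__half>(
        d_out, d_hidden, d_residual, batch_size, seq_length, hidden_dim, stream);
}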
// Element-wise sum of three float tensors; one block handles one row of
// row_stride float4 elements.
__global__ void fused_add3_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);

    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;

    out_4[row * row_stride + id] = val;
}

// Half-precision variant: four halfs per float2 load, widened to float for
// the adds and rounded back to half on store.
__global__ void fused_add3_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);

    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add3<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  const float* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
    const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);

    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];
    float4 inp4_reg = inp4_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add4_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  const __half* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
    const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];
    float2 inp4_4 = inp4_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
    __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
    float2 inp4_h_f_1 = __half22float2(inp4_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);

    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              const float* inp4,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add4<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               const __half* inp4,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
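
// Usage sketch: a hypothetical wrapper around the four-input launcher above.
// The guard spells out a constraint implied by the launch configuration rather
// than stated in the file: each thread handles four floats of one row and a
// CUDA block is limited to 1024 threads, so hidden_size must be a multiple of
// 4 and at most 4096. All names here are illustrative assumptions.
static inline void example_fused_add4_fp32(float* d_out,
                                           const float* d_a,
                                           const float* d_b,
                                           const float* d_c,
                                           const float* d_d,
                                           int batch_size,
                                           int seq_length,
                                           int hidden_size,
                                           cudaStream_t stream)
{
    // Refuse shapes the kernel's one-row-per-block layout cannot cover.
    if (hidden_size % 4 != 0 || hidden_size / 4 > 1024) return;

    launch_fused_add4<float>(
        d_out, d_a, d_b, d_c, d_d, batch_size, seq_length, hidden_size, stream);
}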