// general_kernels.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "general_kernels.h"

namespace cg = cooperative_groups;

// Computes per-column sums of a row-major [rows x width] matrix.
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
                                  T* __restrict__ out,
                                  int rows,
                                  int width)
{
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;

    int y_stride = width * TILE_DIM;

    float localSum = 0;

    // Loop across matrix height: each thread accumulates a partial sum of its column.
    if (idx < width) {
        int offset = threadIdx.y * width + idx;
        for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
            localSum += (float)inp[offset];
            offset += y_stride;
        }
    }

    // Store transposed so the partial sums of one column land in one row of the tile.
    tile[threadIdx.x][threadIdx.y] = localSum;

    __syncthreads();

    // Sum the shared buffer.
    float sum = tile[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    // Reduce the TILE_DIM partial sums within the tile group via warp shuffles.
    for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        if (pos < width) out[pos] = sum;
    }
}

template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
                                              float* out,
                                              int rows,
                                              int cols,
                                              cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}

template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
                                               __half* out,
                                               int rows,
                                               int cols,
                                               cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
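
/*
 * Usage sketch (illustrative only, not part of the original file): the launcher
 * above reduces a row-major [rows x cols] matrix to a length-`cols` vector of
 * per-column sums, which is how a bias gradient is commonly formed from an
 * activation gradient. Buffer names and sizes here are hypothetical.
 *
 *   const int rows = batch_size * seq_length;  // e.g. 8 * 128
 *   const int cols = hidden_size;              // e.g. 1024
 *   float* d_grad;       // device buffer holding rows * cols floats
 *   float* d_bias_grad;  // device buffer holding cols floats
 *   launch_fuse_transpose_bias_kernel<float>(d_grad, d_bias_grad, rows, cols, stream);
 */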

// Element-wise out = inp1 + inp2 over N float4 groups (four floats per iteration).
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    float4* out_4 = reinterpret_cast<float4*>(out);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 val;
        float4 inp1_reg = inp1_4[j];
        float4 inp2_reg = inp2_4[j];

        val.x = inp1_reg.x + inp2_reg.x;
        val.y = inp1_reg.y + inp2_reg.y;
        val.z = inp1_reg.z + inp2_reg.z;
        val.w = inp1_reg.w + inp2_reg.w;

        out_4[j] = val;
    }
}

// Same element-wise sum for __half: each float2 load carries four halves, which are
// widened to float for the add and rounded back to half on the store.
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
    float2 inp1_4;
    float2 inp2_4;

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        inp1_4 = inp1_arr[j];
        inp2_4 = inp2_arr[j];

        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

        float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
        float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

        inp1_h_f_0.x += inp2_h_f_0.x;
        inp1_h_f_0.y += inp2_h_f_0.y;
        inp1_h_f_1.x += inp2_h_f_1.x;
        inp1_h_f_1.y += inp2_h_f_1.y;

        float2 val_f;
        __half2* val_h = reinterpret_cast<__half2*>(&val_f);

        val_h[0] = __float22half2_rn(inp1_h_f_0);
        val_h[1] = __float22half2_rn(inp1_h_f_1);

        float2* out_4 = reinterpret_cast<float2*>(out);
        out_4[j] = val_f;
    }
}

template <>
void launch_fused_add2<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              int batch_size,
                              int seq_length,
                              int hidden_dim,
                              cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}

template <>
void launch_fused_add2<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               int batch_size,
                               int seq_length,
                               int hidden_dim,
                               cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
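
/*
 * Usage sketch (illustrative, not from the original source): both launch_fused_add2
 * specializations flatten the tensors to batch_size * seq_length * hidden_dim elements
 * and process them four at a time (one float4, or one float2 holding four halves), so
 * the total element count is assumed to be divisible by four and the buffers aligned
 * for the 16-byte / 8-byte vector loads. Names and sizes below are hypothetical.
 *
 *   const int batch = 4, seq = 128, hidden = 1024;
 *   __half *d_out, *d_a, *d_b;  // device buffers of batch * seq * hidden halves each
 *   launch_fused_add2<__half>(d_out, d_a, d_b, batch, seq, hidden, stream);
 */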

__global__ void fused_add3_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);

    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add3_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);

    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);

    out_4[row * row_stride + id] = val_f;
}

template <>
void launch_fused_add3<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add3<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
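
/*
 * Note on launch geometry (added commentary, not from the original source): unlike the
 * grid-stride loop used for fused_add2, these launchers assign one block per token
 * (batch_size * seq_length blocks) with hidden_size / 4 threads each, one vector of
 * four elements per thread. That implicitly assumes hidden_size is a multiple of 4 and
 * that hidden_size / 4 stays within the 1024-thread block limit. A hypothetical call:
 *
 *   const int batch = 8, seq = 256, hidden = 1024;  // 1024 / 4 = 256 threads per block
 *   launch_fused_add3<float>(d_out, d_a, d_b, d_c, batch, seq, hidden, stream);
 *
 * The fused_add4 launchers below use the same geometry and carry the same assumptions.
 */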

__global__ void fused_add4_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  const float* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
    const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);

    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];
    float4 inp4_reg = inp4_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add4_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  const __half* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
    const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];
    float2 inp4_4 = inp4_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
    __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
    float2 inp4_h_f_1 = __half22float2(inp4_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);

    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);

    out_4[row * row_stride + id] = val_f;
}

template <>
void launch_fused_add4<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              const float* inp4,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add4<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               const __half* inp4,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}