// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "custom_cuda_layers.h"

#define rows_trans 16
#define cols_trans 16

template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
    __shared__ T data_block[rows_trans * (cols_trans + 1)];

    int r = threadIdx.x / cols_trans;
    int c = threadIdx.x % cols_trans;

    int m = row_width / cols_trans;

    int i = blockIdx.x / m * rows_trans + r;
    int j = blockIdx.x % m * cols_trans + c;

    int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);

    for (int k = 0; k < rows_trans; k += row_stride)
        data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];

    __syncthreads();

    i = blockIdx.x % m * rows_trans + r;
    j = blockIdx.x / m * cols_trans + c;

    for (int k = 0; k < rows_trans; k += row_stride)
        out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}
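// Transpose_Kernel stages a rows_trans x cols_trans tile in shared memory,
// synchronizes, then writes the tile back with row/column indices swapped. The
// extra column in data_block is the usual padding trick for avoiding
// shared-memory bank conflicts on the transposed reads; note, though, that the
// indexing above keeps a cols_trans stride, so the pad is declared but not
// actually exploited here.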
template <>
void Transpose<__half>(const __half* inp_mat,
                       __half* out_mat,
                       int rows,
                       int cols,
                       cudaStream_t stream)
{
    int threads = THREADS;
    Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
        inp_mat, out_mat, cols, rows);
}

template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream)
{
    int threads = THREADS;
    Transpose_Kernel<float><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
        inp_mat, out_mat, cols, rows);
}
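// Minimal usage sketch (dimensions are illustrative assumptions, not
// requirements documented here; both should be multiples of the 16x16 tile):
//
//     float *d_in, *d_out;  // device buffers holding rows * cols floats
//     cudaStream_t stream;
//     cudaStreamCreate(&stream);
//     Transpose<float>(d_in, d_out, 1024, 768, stream);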
template <typename T>
__global__ void transform_0213(T* output,
                               const T* vals,
                               int hidden_dim,
                               int seq_length,
                               int heads,
                               int head_ext);

template <>
__global__ void transform_0213<float>(float* output,
                                      const float* vals,
                                      int hidden_dim,
                                      int seq_length,
                                      int heads,
                                      int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}

template <>
__global__ void transform_0213<__half>(__half* output,
                                       const __half* vals,
                                       int hidden_dim,
                                       int seq_length,
                                       int heads,
                                       int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 8)

    float4 vals_arr[1];

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];
#endif
}
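// transform_0213 permutes a [batch, seq, heads, head_dim] tensor (dims
// 0,1,2,3) into [batch, heads, seq, head_dim] (dims 0,2,1,3), i.e. the
// reshape that splits a packed hidden dimension into the per-head layout
// attention expects. Each thread moves one float4 (4 floats, or 8 halves),
// which is why the launchers below divide hidden_dim by 4 or 8 first.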
template <>
void launch_transform_0213<float>(float* output,
                                  const float* vals,
                                  int batch_size,
                                  int seq_length,
                                  int hidden_dim,
                                  int heads,
                                  cudaStream_t stream)
{
    hidden_dim >>= 2;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, (seq_length * head_ext));

    transform_0213<float>
        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
}

template <>
void launch_transform_0213<__half>(__half* output,
                                   const __half* vals,
                                   int batch_size,
                                   int seq_length,
                                   int hidden_dim,
                                   int heads,
                                   cudaStream_t stream)
{
    hidden_dim >>= 3;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, (seq_length * head_ext));

    transform_0213<__half>
        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
}
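// Minimal usage sketch (shapes are illustrative assumptions): reorder a
// [batch=8, seq=128, hidden=768] activation with 12 heads into per-head
// layout. hidden_dim is passed in elements; the launcher converts it to
// float4 granularity internally, so callers pass the unshifted value:
//
//     launch_transform_0213<float>(d_out, d_in, 8, 128, 768, 12, stream);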
// Bias add fused with the 0213 transform

template <typename T>
__global__ void bias_add_transform_0213(T* output,
                                        const T* vals,
                                        const T* bias,
                                        int hidden_dim,
                                        int seq_length,
                                        int heads,
                                        int head_ext);

template <>
__global__ void bias_add_transform_0213<float>(float* output,
                                               const float* vals,
                                               const float* bias,
                                               int hidden_dim,
                                               int seq_length,
                                               int heads,
                                               int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
    int cnt = blockIdx.z / head_ext;                                      // Hidden count
    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
                             d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
    float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];

    float4 outputs;
    outputs.x = inputs.x + biases.x;
    outputs.y = inputs.y + biases.y;
    outputs.z = inputs.z + biases.z;
    outputs.w = inputs.w + biases.w;

    output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
               d2 * d2_out_stride + d3] = outputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10

template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
                                                const __half* vals,
                                                const __half* bias,
                                                int hidden_dim,
                                                int seq_length,
                                                int heads,
                                                int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
    int cnt = blockIdx.z / head_ext;                                      // Hidden count
    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 8)

    float4 vals_arr;
    float4 bias_arr;
    float4 output_arr;
    __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(&output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
    vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
    vals_vec += (cnt * d1_stride);
    vals_vec += (d2 * d2_stride);

    bias_vec += (cnt * d1_stride);
    bias_vec += (d2 * d2_stride);

    output_vec += (cnt * d0_stride * gridDim.x);
    output_vec += (d1 * d2_stride);
    output_vec += (d0 * d0_stride);
    output_vec += (d2 * d2_out_stride);

    bias_arr = bias_vec[d3];
    vals_arr = vals_vec[d3];

#if defined(__ACC_HALF__)
    output_half[0] = vals_half[0] + bias_half[0];
    output_half[1] = vals_half[1] + bias_half[1];
    output_half[2] = vals_half[2] + bias_half[2];
    output_half[3] = vals_half[3] + bias_half[3];
#else
    float2 bias_arr_f[4];
    float2 vals_arr_f[4];
#pragma unroll
    for (int l = 0; l < 4; l++) {
        bias_arr_f[l] = __half22float2(bias_half[l]);
        vals_arr_f[l] = __half22float2(vals_half[l]);
        vals_arr_f[l].x += bias_arr_f[l].x;
        vals_arr_f[l].y += bias_arr_f[l].y;
        output_half[l] = __float22half2_rn(vals_arr_f[l]);
    }
#endif
    output_vec[d3] = output_arr;
#endif
}
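// In the __half path each thread still moves 16 bytes, but reinterprets the
// float4 registers as four __half2 lanes. When __ACC_HALF__ is defined the
// bias add runs natively in __half2 arithmetic; otherwise each lane is widened
// to float2 with __half22float2, accumulated in fp32, and rounded back with
// __float22half2_rn to limit precision loss.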
__global__ void bias_add_transform_0213_v2(__half* output,
                                           const __half* vals,
                                           const __half* bias,
                                           int hidden_dim,
                                           int seq_length,
                                           int heads)
{
#ifdef HALF_PRECISION_AVAILABLE
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;
    int iteration_stride = d1_stride * blockDim.z;  // Hidden * 3 / 8
    int batch_stride = d0_stride * blockDim.z;      // Hidden * S * 3 / 8

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;    // Batch
    int d1 = blockIdx.y;    // Sequence ID (0-127)
    int cnt = threadIdx.z;  // Hidden count
    int d2 = threadIdx.y;   // Head (0-11)
    int d3 = threadIdx.x;   // Values (groups of 8)

    float4 vals_arr[1];
    float4 bias_arr[1];
    float4 output_arr[1];
    __half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
    int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
    bias_arr[0] = bias_vec[iter_index];

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        vals_arr[0] = vals_vec[input_offset + iter_id];

        output_half[0] = vals_half[0] + bias_half[0];
        output_half[1] = vals_half[1] + bias_half[1];
        output_half[2] = vals_half[2] + bias_half[2];
        output_half[3] = vals_half[3] + bias_half[3];

        in_data[iter_id] = output_arr[0];
    }
    __syncthreads();

    iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_out_stride * gridDim.x);
    int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);

    int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_row = (iter * iteration_stride) + head_count;
        int iter_offset =
            (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
        output_vec[out_index + iter_offset] =
            in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
    }
#endif
}
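// The _v2 variant handles two sequence positions per block and all
// trans_count matrices at once (blockDim.z == trans_count). The first loop
// performs the bias add and stages the results in shared memory (in_data); the
// second re-reads them so the global writes land coalesced in the transformed
// layout. The fixed staging buffer of 3072 float4s (48 KB of shared memory)
// and this block shape only fit moderate hidden sizes, which is presumably why
// the launcher below takes this path only when the vectorized hidden size lies
// in [16, 128].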
// [B S C*H] -> C * [B A S N]
// (B = batch, S = sequence length, C = matrix count (e.g. 3 for fused QKV),
//  A = attention heads, N = head size; A * N = H = hidden)

template <>
void launch_bias_add_transform_0213<float>(float* output,
                                           const float* vals,
                                           const float* bias,
                                           int batch_size,
                                           int seq_length,
                                           int hidden_dim,
                                           int heads,
                                           cudaStream_t stream,
                                           int trans_count)
{
    hidden_dim >>= 2;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));

    bias_add_transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
        output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}

template <>
void launch_bias_add_transform_0213<__half>(__half* output,
                                            const __half* vals,
                                            const __half* bias,
                                            int batch_size,
                                            int seq_length,
                                            int hidden_dim,
                                            int heads,
                                            cudaStream_t stream,
                                            int trans_count)
{
    hidden_dim >>= 3;
    if (hidden_dim > 128 || hidden_dim < 16) {
        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
        dim3 block_dim(hidden_dim / heads, (heads / head_ext));
        dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
        bias_add_transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads, head_ext);
    } else {
        dim3 block_dim(hidden_dim / heads, heads, trans_count);
        dim3 grid_dim(batch_size, seq_length / 2);
        bias_add_transform_0213_v2<<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads);
    }
}
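// Minimal usage sketch (shapes are illustrative assumptions): split a fused
// QKV activation of shape [batch=8, seq=128, 3*hidden=3*768] into the three
// per-head tensors expected by attention, adding the fused QKV bias in the
// same pass:
//
//     launch_bias_add_transform_0213<float>(
//         d_qkv_out, d_qkv_in, d_qkv_bias, 8, 128, 768, 12, stream, 3);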
template <typename T>
__global__ void transform4d_0213(T* out,
                                 const T* in,
                                 int heads,
                                 int seq_length,
                                 int hidden_dim,
                                 int head_ext);

template <>
__global__ void transform4d_0213<float>(float* out,
                                        const float* in,
                                        int heads,
                                        int seq_length,
                                        int hidden_dim,
                                        int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = d0_stride / heads;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = hidden_dim;

    int d0 = blockIdx.x;                                            // Batch
    int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1);      // Head
    int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;  // Sequence
    int cnt = blockIdx.z;                                           // Hidden count
    int d3 = threadIdx.x;                                           // Values (groups of 4)

    if (d2 < seq_length) {
        const float4* in_vec = reinterpret_cast<const float4*>(in);
        float4* out_vec = reinterpret_cast<float4*>(out);

        float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
                                 d2 * d2_stride + d3];
        out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
                d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
    }
}

template <>
__global__ void transform4d_0213<__half>(__half* out,
                                         const __half* in,
                                         int heads,
                                         int seq_length,
                                         int hidden_dim,
                                         int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
    int d0_stride = hidden_dim * (seq_length / head_ext);
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head
    int d2 = blockIdx.z / head_ext;                                       // Sequence
    int cnt = blockIdx.y;                                                 // Hidden count
    int d3 = threadIdx.x;                                                 // Values (groups of 8)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    in_vec += (cnt * d0_stride * gridDim.x);
    in_vec += (d0 * d0_stride);
    in_vec += (d2 * d2_stride);
    in_vec += (d1 * d2_stride * seq_length);

    out_vec += (cnt * d1_stride);
    out_vec += (d1 * d2_stride);
    out_vec += (d0 * d0_stride * gridDim.y);
    out_vec += (d2 * d1_stride * gridDim.y);

    out_vec[d3] = in_vec[d3];
#endif
}
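// transform4d_0213 is the inverse of the 0213 split above: it merges the
// trans_count per-head [batch, heads, seq, head_dim] tensors back into a
// packed [batch, seq, trans_count * hidden] activation, again moving one
// float4 per thread.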
__global__ void transform4d_0213_v2(__half* out,
                                    const __half* in,
                                    int heads,
                                    int seq_length,
                                    int hidden_dim)
{
#ifdef HALF_PRECISION_AVAILABLE
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0 = blockIdx.x;    // Batch
    int d1 = threadIdx.y;   // Head
    int d2 = blockIdx.y;    // Sequence
    int cnt = threadIdx.z;  // Hidden count
    int d3 = threadIdx.x;   // Values (groups of 8)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
    int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
    int iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_stride * gridDim.x);

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_row = iter * iteration_stride + head_count;
        int iter_offset = (iter_row % blockDim.y) * d2_stride;

        in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
            in_vec[input_offset + iter_offset * seq_length +
                   (iter_row / blockDim.y) * matrix_stride];
    }
    __syncthreads();

    iteration_stride = d1_stride * blockDim.z;
    int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
    int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        out_vec[output_offset + iter_id] = in_data[iter_id];
    }
#endif
}
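// Mirror of bias_add_transform_0213_v2: reads from the per-head layout are
// staged through shared memory so the packed [batch, seq, trans_count * hidden]
// output can be written contiguously, two sequence positions per block.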
// 3 * [B A S N] -> [B S C*H]

template <>
void launch_transform4d_0213<float>(float* out,
                                    const float* in,
                                    int batch_size,
                                    int heads,
                                    int seq_length,
                                    int hidden_dim,
                                    cudaStream_t stream,
                                    int trans_count)
{
    hidden_dim >>= 2;
    dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
    dim3 block_dims(hidden_dim / heads, 8);
    transform4d_0213<float>
        <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim, 1);
}

template <>
void launch_transform4d_0213<__half>(__half* out,
                                     const __half* in,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     int trans_count)
{
    hidden_dim >>= 3;
    if (hidden_dim > 128 || hidden_dim < 16) {
        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
        dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
        dim3 block_dims(hidden_dim / heads, (heads / head_ext));
        transform4d_0213<__half><<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim, head_ext);
    } else {
        dim3 grid_dims(batch_size, seq_length / 2);
        dim3 block_dims(hidden_dim / heads, heads, trans_count);
        transform4d_0213_v2<<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim);
    }
}
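// Minimal usage sketch (shapes are illustrative assumptions): merge the three
// attention-layout tensors produced by launch_bias_add_transform_0213 back
// into a packed [batch=8, seq=128, 3*hidden=3*768] buffer:
//
//     launch_transform4d_0213<float>(d_packed, d_attn, 8, 12, 128, 768, stream, 3);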