#include "custom_cuda_layers.h"

#define rows_trans 16
#define cols_trans 16

template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
    __shared__ T data_block[rows_trans * (cols_trans + 1)];

    int r = threadIdx.x / cols_trans;
    int c = threadIdx.x % cols_trans;

    int m = row_width / cols_trans;

    int i = blockIdx.x / m * rows_trans + r;
    int j = blockIdx.x % m * cols_trans + c;

    int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);

    for (int k = 0; k < rows_trans; k += row_stride)
        data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];

    __syncthreads();

    i = blockIdx.x % m * rows_trans + r;
    j = blockIdx.x / m * cols_trans + c;

    for (int k = 0; k < rows_trans; k += row_stride)
        out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}
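
// Note (hedged): assuming THREADS is defined as 256 in custom_cuda_layers.h, row_stride
// evaluates to 16, so each 256-thread block copies exactly one 16x16 tile in a single
// loop iteration. The tile is staged in shared memory and written back with the row and
// column roles of blockIdx.x swapped, producing the transpose one tile at a time.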

template <>
void Transpose<__half>(const __half* inp_mat,
                       __half* out_mat,
                       int rows,
                       int cols,
                       cudaStream_t stream)
{
    int threads = THREADS;

    Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
        inp_mat, out_mat, cols, rows);
}

template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream)
{
    int threads = THREADS;

    Transpose_Kernel<float><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
        inp_mat, out_mat, cols, rows);
}
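
// Hedged usage sketch (the buffer names and the 1024x1024 size below are illustrative,
// not part of this file): transposing a row-major float matrix that already lives on
// the device.
//
//     float* d_in;   // 1024 * 1024 input values, device memory
//     float* d_out;  // 1024 * 1024 output values, receives the transpose
//     cudaStream_t stream;
//     cudaStreamCreate(&stream);
//     Transpose<float>(d_in, d_out, 1024, 1024, stream);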

template <typename T>
__global__ void transform_0213(T* output,
                               const T* vals,
                               int hidden_dim,
                               int seq_length,
                               int heads,
                               int head_ext);

template <>
__global__ void transform_0213<float>(float* output,
                                      const float* vals,
                                      int hidden_dim,
                                      int seq_length,
                                      int heads,
                                      int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}

template <>
__global__ void transform_0213<__half>(__half* output,
                                       const __half* vals,
                                       int hidden_dim,
                                       int seq_length,
                                       int heads,
                                       int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    float4 vals_arr[1];

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];

#endif
}

template <>
void launch_transform_0213<float>(float* output,
                                  const float* vals,
                                  int batch_size,
                                  int seq_length,
                                  int hidden_dim,
                                  int heads,
                                  cudaStream_t stream)
{
    hidden_dim >>= 2;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, (seq_length * head_ext));

    transform_0213<float>
        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
}

template <>
void launch_transform_0213<__half>(__half* output,
                                   const __half* vals,
                                   int batch_size,
                                   int seq_length,
                                   int hidden_dim,
                                   int heads,
                                   cudaStream_t stream)
{
    hidden_dim >>= 3;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, (seq_length * head_ext));

    transform_0213<__half>
        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
}
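
// Worked example (hedged; the concrete sizes are illustrative and MAX_THREADS is assumed
// to be 1024 as in custom_cuda_layers.h): for a float input with batch_size = 8,
// seq_length = 128, hidden_dim = 768 and heads = 12, hidden_dim is first divided by 4
// (192 float4 elements per token), head_ext evaluates to 1, so the launch uses
// block_dim = (16, 12) and grid_dim = (8, 128). Each thread then moves one float4,
// permuting the layout from [batch, seq, heads, head_dim] to [batch, heads, seq, head_dim].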

// Bias add

template <typename T>
__global__ void bias_add_transform_0213(T* output,
                                        const T* vals,
                                        const T* bias,
                                        int hidden_dim,
                                        int seq_length,
                                        int heads,
                                        int head_ext);

template <>
__global__ void bias_add_transform_0213<float>(float* output,
                                               const float* vals,
                                               const float* bias,
                                               int hidden_dim,
                                               int seq_length,
                                               int heads,
                                               int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
    int cnt = blockIdx.z / head_ext;                                      // Hidden count
    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
                             d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
    float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];

    float4 outputs;
    outputs.x = inputs.x + biases.x;
    outputs.y = inputs.y + biases.y;
    outputs.z = inputs.z + biases.z;
    outputs.w = inputs.w + biases.w;

    output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
               d2 * d2_out_stride + d3] = outputs;
}

#define ATTN_H 3
#define MAX_SEQ_LINE 10

template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
                                                const __half* vals,
                                                const __half* bias,
                                                int hidden_dim,
                                                int seq_length,
                                                int heads,
                                                int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
    int cnt = blockIdx.z / head_ext;                                      // Hidden count
    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    float4 vals_arr;
    float4 bias_arr;
    float4 output_arr;
    __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(&output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
    vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
    vals_vec += (cnt * d1_stride);
    vals_vec += (d2 * d2_stride);

    bias_vec += (cnt * d1_stride);
    bias_vec += (d2 * d2_stride);

    output_vec += (cnt * d0_stride * gridDim.x);
    output_vec += (d1 * d2_stride);
    output_vec += (d0 * d0_stride);
    output_vec += (d2 * d2_out_stride);

    bias_arr = bias_vec[d3];
    vals_arr = vals_vec[d3];

#if defined(__ACC_HALF__)
    output_half[0] = vals_half[0] + bias_half[0];
    output_half[1] = vals_half[1] + bias_half[1];
    output_half[2] = vals_half[2] + bias_half[2];
    output_half[3] = vals_half[3] + bias_half[3];
#else
    float2 bias_arr_f[4];
    float2 vals_arr_f[4];
#pragma unroll
    for (int l = 0; l < 4; l++) {
        bias_arr_f[l] = __half22float2(bias_half[l]);
        vals_arr_f[l] = __half22float2(vals_half[l]);
        vals_arr_f[l].x += bias_arr_f[l].x;
        vals_arr_f[l].y += bias_arr_f[l].y;
        output_half[l] = __float22half2_rn(vals_arr_f[l]);
    }
#endif
    output_vec[d3] = output_arr;

#endif
}

__global__ void bias_add_transform_0213_v2(__half* output,
                                           const __half* vals,
                                           const __half* bias,
                                           int hidden_dim,
                                           int seq_length,
                                           int heads)
{
#if __CUDA_ARCH__ >= 700
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;
    int iteration_stride = d1_stride * blockDim.z;  // Hidden * 3 / 8
    int batch_stride = d0_stride * blockDim.z;      // Hidden * S * 3 / 8

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;    // Batch
    int d1 = blockIdx.y;    // Sequence ID (0-127)
    int cnt = threadIdx.z;  // blockIdx.z; // Hidden count
    int d2 = threadIdx.y;   // Head (0-11)
    int d3 = threadIdx.x;   // Values (groups of 4)

    float4 vals_arr[1];
    float4 bias_arr[1];
    float4 output_arr[1];
    __half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
    int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
    bias_arr[0] = bias_vec[iter_index];

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        vals_arr[0] = vals_vec[input_offset + iter_id];

        output_half[0] = vals_half[0] + bias_half[0];
        output_half[1] = vals_half[1] + bias_half[1];
        output_half[2] = vals_half[2] + bias_half[2];
        output_half[3] = vals_half[3] + bias_half[3];

        in_data[iter_id] = output_arr[0];
    }
    __syncthreads();

    iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_out_stride * gridDim.x);
    int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);

    int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_row = (iter * iteration_stride) + head_count;
        int iter_offset =
            (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
        output_vec[out_index + iter_offset] =
            in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
    }
#endif
}
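
// Note (hedged): in_data above holds 3072 float4 values, i.e. 48 KB of statically
// allocated shared memory per block, which is the per-block static shared memory limit
// on CUDA devices. The launchers below only take this v2 path when hidden_dim / 8 lies
// in [16, 128]; outside that range the per-element __half kernel above is used instead.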

// [B S C*H] -> C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
                                           const float* vals,
                                           const float* bias,
                                           int batch_size,
                                           int seq_length,
                                           int hidden_dim,
                                           int heads,
                                           cudaStream_t stream,
                                           int trans_count)
{
    hidden_dim >>= 2;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));

    bias_add_transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
        output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}

template <>
void launch_bias_add_transform_0213<__half>(__half* output,
                                            const __half* vals,
                                            const __half* bias,
                                            int batch_size,
                                            int seq_length,
                                            int hidden_dim,
                                            int heads,
                                            cudaStream_t stream,
                                            int trans_count)
{
    hidden_dim >>= 3;
    if (hidden_dim > 128 || hidden_dim < 16) {
        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
        dim3 block_dim(hidden_dim / heads, (heads / head_ext));
        dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
        bias_add_transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads, head_ext);
    } else {
        dim3 block_dim(hidden_dim / heads, heads, trans_count);
        dim3 grid_dim(batch_size, seq_length / 2);
        bias_add_transform_0213_v2<<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads);
    }
}
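
// Hedged usage sketch (the shapes and pointer names below are illustrative): for a fused
// QKV projection with batch_size = 8, seq_length = 128, hidden_dim = 768, heads = 12 and
// trans_count = 3, vals holds [8, 128, 3 * 768] activations and bias holds [3 * 768]
// values; the call writes three [8, 12, 128, 64] tensors (Q, K, V) back to back into
// output, adding the bias along the way.
//
//     launch_bias_add_transform_0213<__half>(
//         d_qkv_out, d_qkv_in, d_qkv_bias, 8, 128, 768, 12, stream, 3);
//
// The device pointers and the stream are assumed to be allocated elsewhere.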

template <typename T>
__global__ void transform4d_0213(T* out,
                                 const T* in,
                                 int heads,
                                 int seq_length,
                                 int hidden_dim,
                                 int head_ext);

template <>
__global__ void transform4d_0213<float>(float* out,
                                        const float* in,
                                        int heads,
                                        int seq_length,
                                        int hidden_dim,
                                        int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = d0_stride / heads;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = hidden_dim;

    int d0 = blockIdx.x;                                        // Batch
    int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1);  // Head
    int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
    int cnt = blockIdx.z;
    int d3 = threadIdx.x;  // Values (groups of 8)

    if (d2 < seq_length) {
        const float4* in_vec = reinterpret_cast<const float4*>(in);
        float4* out_vec = reinterpret_cast<float4*>(out);

        float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
                                 d2 * d2_stride + d3];
        out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
                d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
    }
}

template <>
__global__ void transform4d_0213<__half>(__half* out,
                                         const __half* in,
                                         int heads,
                                         int seq_length,
                                         int hidden_dim,
                                         int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * (seq_length / head_ext);
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head
    int d2 = blockIdx.z / head_ext;                                       // Sequence
    int cnt = blockIdx.y;                                                 // Hidden count
    int d3 = threadIdx.x;                                                 // Values (groups of 8)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    in_vec += (cnt * d0_stride * gridDim.x);
    in_vec += (d0 * d0_stride);
    in_vec += (d2 * d2_stride);
    in_vec += (d1 * d2_stride * seq_length);

    out_vec += (cnt * d1_stride);
    out_vec += (d1 * d2_stride);
    out_vec += (d0 * d0_stride * gridDim.y);
    out_vec += (d2 * d1_stride * gridDim.y);

    out_vec[d3] = in_vec[d3];

#endif
}

__global__ void transform4d_0213_v2(__half* out,
                                    const __half* in,
                                    int heads,
                                    int seq_length,
                                    int hidden_dim)
{
#if __CUDA_ARCH__ >= 700
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0 = blockIdx.x;    // Batch
    int d1 = threadIdx.y;   // Head
    int d2 = blockIdx.y;    // Sequence
    int cnt = threadIdx.z;  // Hidden count
    int d3 = threadIdx.x;   // Values (groups of 8)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
    int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
    int iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_stride * gridDim.x);

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_row = iter * iteration_stride + head_count;
        int iter_offset = (iter_row % blockDim.y) * d2_stride;

        in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
            in_vec[input_offset + iter_offset * seq_length +
                   (iter_row / blockDim.y) * matrix_stride];
    }
    __syncthreads();

    iteration_stride = d1_stride * blockDim.z;
    int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
    int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        out_vec[output_offset + iter_id] = in_data[iter_id];
    }
#endif
}

// 3 * [B A S N] -> [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out,
                                    const float* in,
                                    int batch_size,
                                    int heads,
                                    int seq_length,
                                    int hidden_dim,
                                    cudaStream_t stream,
                                    int trans_count)
{
    hidden_dim >>= 2;
    dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
    dim3 block_dims(hidden_dim / heads, 8);
    transform4d_0213<float>
        <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim, 1);
}

template <>
void launch_transform4d_0213<__half>(__half* out,
                                     const __half* in,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     int trans_count)
{
    hidden_dim >>= 3;
    if (hidden_dim > 128 || hidden_dim < 16) {
        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
        dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
        dim3 block_dims(hidden_dim / heads, (heads / head_ext));
        transform4d_0213<__half><<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim, head_ext);
    } else {
        dim3 grid_dims(batch_size, seq_length / 2);
        dim3 block_dims(hidden_dim / heads, heads, trans_count);
        transform4d_0213_v2<<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim);
    }
}
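
// Hedged usage note: launch_transform4d_0213 is roughly the inverse layout permutation of
// launch_bias_add_transform_0213 (without the bias add). Continuing the illustrative
// shapes from the earlier sketch (batch 8, seq 128, hidden 768, heads 12), collapsing an
// attention-context tensor back to [batch, seq, hidden] could look like:
//
//     launch_transform4d_0213<__half>(
//         d_context_flat, d_context, 8, 12, 128, 768, stream, 1);
//
// where d_context is assumed to be [8, 12, 128, 64], d_context_flat receives
// [8, 128, 768], and trans_count = 1 because only a single matrix is reshaped here.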