// softmax_kernels.cu

#include <math.h>

#include "custom_cuda_layers.h"
#include "general_kernels.h"

namespace cg = cooperative_groups;

dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads)
{
    int seq_length4 = sequence_length / 4;
    int block_compute_size =
        (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
    // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited:
    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
    // The batch size is typically relatively small, while the sequence length could potentially be
    // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit.
    unsigned x = heads * sequence_length / block_compute_size;
    unsigned y = batch_size;
    return {x, y};
}
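// Worked example for the grid sizing above (values chosen for illustration only): with
// threads = 128 and sequence_length = 64, seq_length4 = 16, so
// block_compute_size = 2^floor(log2(128 / 16)) = 8, giving a grid of
// x = heads * 64 / 8 = 8 * heads blocks and y = batch_size.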
// Fused attention-mask add + numerically stable softmax over the attention scores
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
                             const float* attn_mask,
                             int heads,
                             int seq_length,
                             int iterations)
{
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> 5;
    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int batch = blockIdx.y;
    int row = blockIdx.x;
    int max_threads_in_sequence = std::max(seq_length, tbSeq);
    int seq_lane = threadIdx.x % max_threads_in_sequence;

    int data_offset = batch * (gridDim.x * block_width) + row * block_width +
                      (threadIdx.x / max_threads_in_sequence) * seq_length;
    int mask_offset = batch * seq_length;

    int wid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;

    float4* val_cast = reinterpret_cast<float4*>(vals);
    const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);

    float4 data[MAX_THREAD_ITERATIONS];

    float max_val = minus_infinity;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float4 mask = attn_mask_cast[mask_offset + data_id];
            data[i] = val_cast[data_offset + data_id];

            data[i].x += mask.x;
            data[i].y += mask.y;
            data[i].z += mask.z;
            data[i].w += mask.w;

            max_val = (data[i].x > max_val ? data[i].x : max_val);
            max_val = (data[i].y > max_val ? data[i].y : max_val);
            max_val = (data[i].z > max_val ? data[i].z : max_val);
            max_val = (data[i].w > max_val ? data[i].w : max_val);
        } else {
            data[i].x = minus_infinity;
            data[i].y = minus_infinity;
            data[i].z = minus_infinity;
            data[i].w = minus_infinity;
        }
    }

    for (int i = 1; i < tbSize; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / tbSize);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        data[i].x = __expf(data[i].x - max_val);
        data[i].y = __expf(data[i].y - max_val);
        data[i].z = __expf(data[i].z - max_val);
        data[i].w = __expf(data[i].w - max_val);

        sum += (data[i].x + data[i].y + data[i].z + data[i].w);
    }

    for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / tbSize);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        data[i].x /= sum;
        data[i].y /= sum;
        data[i].z /= sum;
        data[i].w /= sum;

        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
    }
}
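// __half variant of the kernel above: the algorithm is identical, but scores and mask are
// loaded as float2 chunks, reinterpreted as pairs of __half2, unpacked to float for the
// max/sum reductions, and repacked with __float22half2_rn on write-back. The body is only
// compiled for __CUDA_ARCH__ >= 700.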
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
                             const __half* attn_mask,
                             int heads,
                             int seq_length,
                             int iterations)
{
#if __CUDA_ARCH__ >= 700
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> 5;
    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int batch = blockIdx.y;
    int row = blockIdx.x;
    int max_threads_in_sequence = std::max(seq_length, tbSeq);
    int seq_lane = threadIdx.x % max_threads_in_sequence;

    int data_offset = batch * (gridDim.x * block_width) + row * block_width +
                      (threadIdx.x / max_threads_in_sequence) * seq_length;
    int mask_offset = batch * seq_length;

    int wid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;

    float2* val_cast = reinterpret_cast<float2*>(vals);
    const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);

    val_cast += data_offset;
    attn_mask_cast += mask_offset;

    float2 low_data[MAX_THREAD_ITERATIONS];
    float2 high_data[MAX_THREAD_ITERATIONS];

    float max_val = minus_infinity;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float2 data = val_cast[data_id];
            float2 mask = attn_mask_cast[data_id];

            __half2* data_arr = reinterpret_cast<__half2*>(&data);
            __half2* mask_arr = reinterpret_cast<__half2*>(&mask);

            low_data[i] = __half22float2(data_arr[0]);
            high_data[i] = __half22float2(data_arr[1]);
            float2 low_mask = __half22float2(mask_arr[0]);
            float2 high_mask = __half22float2(mask_arr[1]);

            low_data[i].x += low_mask.x;
            low_data[i].y += low_mask.y;
            high_data[i].x += high_mask.x;
            high_data[i].y += high_mask.y;

            max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
            max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
            max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
            max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / tbSize);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            low_data[i].x = __expf(low_data[i].x - max_val);
            low_data[i].y = __expf(low_data[i].y - max_val);
            high_data[i].x = __expf(high_data[i].x - max_val);
            high_data[i].y = __expf(high_data[i].y - max_val);

            sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / tbSize);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float2 result_f;
            __half2* result_h = reinterpret_cast<__half2*>(&result_f);

            low_data[i].x /= sum;
            low_data[i].y /= sum;
            high_data[i].x /= sum;
            high_data[i].y /= sum;

            result_h[0] = __float22half2_rn(low_data[i]);
            result_h[1] = __float22half2_rn(high_data[i]);

            val_cast[data_id] = result_f;
        }
    }
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);

template <>
void launch_attn_softmax<float>(float* vals,
                                const float* attn_mask,
                                int batch_size,
                                int heads,
                                int sequence_length,
                                cudaStream_t stream)
{
    const int threads = 128;
    int seq_length4 = sequence_length / 4;

    dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

    int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

    dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                            subblock_max_workload * threads)
                                         : threads);

    int iterations =
        (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                 : MAX_THREAD_ITERATIONS);

    if (sequence_length <= 8)
        attn_softmax<2, (threads / 2), 2>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 16)
        attn_softmax<4, (threads / 4), 4>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 32)
        attn_softmax<8, (threads / 8), 8>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 64)
        attn_softmax<16, (threads / 16), 16>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 128)
        attn_softmax<32, (threads / 32), 32>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 256)
        attn_softmax<32, (threads / 64), 64>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else {
        const int threads = 256;
        dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

        int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

        dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                subblock_max_workload * threads)
                                             : threads);
        iterations =
            (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                     : MAX_THREAD_ITERATIONS);
        if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
            attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else
            throw std::runtime_error(
                "Unsupported sequence length! Check the restriction of the max_threads and "
                "max_thread_iterations!");
    }
}
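// Hedged reference implementation (not part of the original kernels): a straightforward CPU
// masked softmax, convenient for spot-checking the fused kernels above on small shapes. The
// layout assumptions are illustrative: `vals` is treated as [batch, heads, rows, sequence_length]
// in row-major order, and `attn_mask` as [batch, sequence_length], broadcast over heads and rows,
// which matches how the kernels index the mask.
inline void attn_softmax_cpu_reference(float* vals,
                                       const float* attn_mask,
                                       int batch_size,
                                       int heads,
                                       int rows,
                                       int sequence_length)
{
    for (int b = 0; b < batch_size; b++) {
        const float* mask = attn_mask + b * sequence_length;
        for (int h = 0; h < heads; h++) {
            for (int r = 0; r < rows; r++) {
                float* row = vals + ((b * heads + h) * rows + r) * sequence_length;

                // Numerically stable softmax: subtract the row max before exponentiating.
                float max_val = -INFINITY;
                for (int c = 0; c < sequence_length; c++)
                    max_val = fmaxf(max_val, row[c] + mask[c]);

                float sum = 0.f;
                for (int c = 0; c < sequence_length; c++) {
                    row[c] = expf(row[c] + mask[c] - max_val);
                    sum += row[c];
                }
                for (int c = 0; c < sequence_length; c++) row[c] /= sum;
            }
        }
    }
}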
template <>
void launch_attn_softmax<__half>(__half* vals,
                                 const __half* attn_mask,
                                 int batch_size,
                                 int heads,
                                 int sequence_length,
                                 cudaStream_t stream)
{
    const int threads = 128;
    int seq_length4 = sequence_length / 4;

    dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

    int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

    dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                            subblock_max_workload * threads)
                                         : threads);

    int iterations =
        (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                 : MAX_THREAD_ITERATIONS);

    if (sequence_length <= 8)
        attn_softmax<2, (threads / 2), 2>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 16)
        attn_softmax<4, (threads / 4), 4>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 32)
        attn_softmax<8, (threads / 8), 8>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 64)
        attn_softmax<16, (threads / 16), 16>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 128)
        attn_softmax<32, (threads / 32), 32>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 256)
        attn_softmax<32, (threads / 64), 64>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else {
        const int threads = 256;
        dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

        int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

        dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                subblock_max_workload * threads)
                                             : threads);
        iterations =
            (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                     : MAX_THREAD_ITERATIONS);
        if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
            attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else
            throw std::runtime_error(
                "Unsupported sequence length! Check the restriction of the max_threads and "
                "max_thread_iterations!");
    }
}
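// Softmax backward: given y = softmax(x) and the incoming gradient g = dL/dy, the input gradient
// is dL/dx_i = y_i * (g_i - sum_j g_j * y_j). Both kernels below first reduce the row-wise dot
// product sum_j g_j * y_j (with warp shuffles, plus shared memory when a row spans multiple
// warps) and then apply this formula in place on out_grad.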
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> 5;  // warp-count = num_threads / WARP_SIZE (32)

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
                          ? (seq_length + iteration_stride - 1) / iteration_stride
                          : MAX_THREAD_ITERATIONS);

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;

    int wid = id >> 5;
    int lane = id & 0x1f;

    T val_reg[MAX_THREAD_ITERATIONS];
    T soft_reg[MAX_THREAD_ITERATIONS];
    float grad_reg = 0.0f;

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + id;
        if (data_id < block_width) {
            val_reg[i] = out_grad[row * block_width + data_id];
            soft_reg[i] = soft_inp[row * block_width + data_id];
            // Accumulate the dot product in float; doing this multiplication in half
            // precision can cost roughly 2% accuracy.
            grad_reg += ((float)val_reg[i] * (float)soft_reg[i]);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = grad_reg;
        b.sync();

        if (lane < warp_num) grad_reg = partialSum[lane];

        int iters = warp_num;
        if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);

        for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);

        grad_reg = g.shfl(grad_reg, id / tbSize);
    }

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + id;
        if (data_id < block_width) {
            float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
            out_grad[row * block_width + data_id] = (T)temp;
        }
    }
}
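// softmax_backward_kernel_v2 assigns one warp per softmax row: blockDim is
// (WARP_SIZE, warps_per_block), threadIdx.y selects the row within the block, and each lane
// keeps ITERATIONS elements (strided by WARP_SIZE) in registers. The dot product
// sum_j g_j * y_j is reduced with warp shuffles and the result is written back into `grad`
// in place.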
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output */,
                                           const T* output,
                                           int softmax_length)
{
    int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
    int offset = batch_idx * softmax_length + threadIdx.x;

    grad += offset;
    output += offset;

    T grad_reg[ITERATIONS];
    T output_reg[ITERATIONS];
    float sum = 0.0;

#pragma unroll
    for (int i = 0; i < ITERATIONS; ++i) {
        int curr_idx = threadIdx.x + i * WARP_SIZE;
        if (curr_idx < softmax_length) {
            grad_reg[i] = grad[i * WARP_SIZE];
            output_reg[i] = output[i * WARP_SIZE];
            sum += (float)grad_reg[i] * (float)output_reg[i];
        }
    }

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
    for (int i = 0; i < ITERATIONS; ++i) {
        int curr_idx = threadIdx.x + i * WARP_SIZE;
        if (curr_idx < softmax_length)
            grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
    }
}
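// The launcher below processes warps_per_block rows per thread block and picks the smallest
// ITERATIONS such that ITERATIONS * WARP_SIZE covers the row (e.g. seq_length <= 32 -> 1,
// <= 64 -> 2, ..., <= 2048 -> 64); rows longer than 2048 are rejected.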
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream)
{
    const int warps_per_block = 4;
    dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
    dim3 block_dim(WARP_SIZE, warps_per_block);

    if (seq_length <= 32)
        softmax_backward_kernel_v2<T, 1>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 64)
        softmax_backward_kernel_v2<T, 2>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 128)
        softmax_backward_kernel_v2<T, 4>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 256)
        softmax_backward_kernel_v2<T, 8>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 384)
        softmax_backward_kernel_v2<T, 12>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 512)
        softmax_backward_kernel_v2<T, 16>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 768)
        softmax_backward_kernel_v2<T, 24>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 1024)
        softmax_backward_kernel_v2<T, 32>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 2048)
        softmax_backward_kernel_v2<T, 64>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else
        throw std::runtime_error(
            std::string("Unsupported sequence length in softmax backward, seq_length: ") +
            std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
                                                      const __half* soft_inp,
                                                      int batch_size,
                                                      int heads,
                                                      int seq_length,
                                                      cudaStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
                                                     const float* soft_inp,
                                                     int batch_size,
                                                     int heads,
                                                     int seq_length,
                                                     cudaStream_t stream);
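// Minimal usage sketch (illustrative only; the device-pointer names and shapes below are
// assumptions, not part of this file). With attention scores laid out as
// [batch, heads, seq_len, seq_len] and the mask as [batch, seq_len]:
//
//   launch_attn_softmax<float>(d_scores, d_mask, batch, heads, seq_len, stream);
//   launch_attn_softmax_backward_v2<float>(d_grad, d_probs, batch, heads, seq_len, stream);
//
// The first call overwrites d_scores with masked softmax probabilities; the second overwrites
// d_grad (dL/d(probs)) with dL/d(scores), given d_probs from the forward pass.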