softmax_kernels.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <math.h>
#include "custom_cuda_layers.h"
#include "general_kernels.h"

namespace cg = cooperative_groups;

dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads)
{
    int seq_length4 = sequence_length / 4;
    int block_compute_size =
        (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
    // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited:
    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
    // The batch size is typically relatively small, while the sequence length could potentially be
    // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit.
    unsigned x = heads * sequence_length / block_compute_size;
    unsigned y = batch_size;
    return {x, y};
}
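
// Worked example (illustrative numbers, not from the original source): with
// batch_size = 8, heads = 16, sequence_length = 128 and threads = 128, we get
// seq_length4 = 32 and block_compute_size = 2^floor(log2(128 / 32)) = 4, so the
// grid is x = 16 * 128 / 4 = 512 blocks by y = 8 batches.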

// Fused attention + softmax
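// Both kernels below add the attention mask to the scores and then apply the
// numerically stable softmax
//     softmax(x)_i = exp(x_i - max_j x_j) / sum_j exp(x_j - max_j x_j),
// reducing the max and the sum first within a thread-block tile via shfl_xor
// and, when seq_length > tbSize, across warps through the partialSum shared
// buffer. A small epsilon (1e-6) is added to the denominator before dividing.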
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
                             const float* attn_mask,
                             int heads,
                             int seq_length,
                             int iterations)
{
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> WARP_SIZE_BITS;

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int batch = blockIdx.y;
    int row = blockIdx.x;
    int max_threads_in_sequence = std::max(seq_length, tbSeq);
    int seq_lane = threadIdx.x % max_threads_in_sequence;

    int data_offset = batch * (gridDim.x * block_width) + row * block_width +
                      (threadIdx.x / max_threads_in_sequence) * seq_length;
    int mask_offset = batch * seq_length;

    int wid = threadIdx.x >> WARP_SIZE_BITS;
    int lane = threadIdx.x & 0x1f;

    float4* val_cast = reinterpret_cast<float4*>(vals);
    const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);

    float4 data[MAX_THREAD_ITERATIONS];

    float max_val = minus_infinity;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float4 mask = attn_mask_cast[mask_offset + data_id];
            data[i] = val_cast[data_offset + data_id];

            data[i].x += mask.x;
            data[i].y += mask.y;
            data[i].z += mask.z;
            data[i].w += mask.w;

            max_val = (data[i].x > max_val ? data[i].x : max_val);
            max_val = (data[i].y > max_val ? data[i].y : max_val);
            max_val = (data[i].z > max_val ? data[i].z : max_val);
            max_val = (data[i].w > max_val ? data[i].w : max_val);
        } else {
            data[i].x = minus_infinity;
            data[i].y = minus_infinity;
            data[i].z = minus_infinity;
            data[i].w = minus_infinity;
        }
    }

    for (int i = 1; i < tbSize; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = max_val;

        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / tbSize);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        data[i].x = __expf(data[i].x - max_val);
        data[i].y = __expf(data[i].y - max_val);
        data[i].z = __expf(data[i].z - max_val);
        data[i].w = __expf(data[i].w - max_val);

        sum += (data[i].x + data[i].y + data[i].z + data[i].w);
    }

    for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = sum;

        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / tbSize);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        data[i].x /= sum;
        data[i].y /= sum;
        data[i].z /= sum;
        data[i].w /= sum;

        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
    }
}
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
                             const __half* attn_mask,
                             int heads,
                             int seq_length,
                             int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> WARP_SIZE_BITS;

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int batch = blockIdx.y;
    int row = blockIdx.x;
    int max_threads_in_sequence = std::max(seq_length, tbSeq);
    int seq_lane = threadIdx.x % max_threads_in_sequence;

    int data_offset = batch * (gridDim.x * block_width) + row * block_width +
                      (threadIdx.x / max_threads_in_sequence) * seq_length;
    int mask_offset = batch * seq_length;

    int wid = threadIdx.x >> WARP_SIZE_BITS;
    int lane = threadIdx.x & 0x1f;

    float2* val_cast = reinterpret_cast<float2*>(vals);
    const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);

    val_cast += data_offset;
    attn_mask_cast += mask_offset;

    float2 low_data[MAX_THREAD_ITERATIONS];
    float2 high_data[MAX_THREAD_ITERATIONS];

    float max_val = minus_infinity;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float2 data = val_cast[data_id];
            float2 mask = attn_mask_cast[data_id];

            __half2* data_arr = reinterpret_cast<__half2*>(&data);
            __half2* mask_arr = reinterpret_cast<__half2*>(&mask);

            low_data[i] = __half22float2(data_arr[0]);
            high_data[i] = __half22float2(data_arr[1]);
            float2 low_mask = __half22float2(mask_arr[0]);
            float2 high_mask = __half22float2(mask_arr[1]);

            low_data[i].x += low_mask.x;
            low_data[i].y += low_mask.y;
            high_data[i].x += high_mask.x;
            high_data[i].y += high_mask.y;

            max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
            max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
            max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
            max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = max_val;

        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / tbSize);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            low_data[i].x = __expf(low_data[i].x - max_val);
            low_data[i].y = __expf(low_data[i].y - max_val);
            high_data[i].x = __expf(high_data[i].x - max_val);
            high_data[i].y = __expf(high_data[i].y - max_val);

            sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = sum;

        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / tbSize);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;

        if (data_id < seq_length) {
            float2 result_f;
            __half2* result_h = reinterpret_cast<__half2*>(&result_f);

            low_data[i].x /= sum;
            low_data[i].y /= sum;
            high_data[i].x /= sum;
            high_data[i].y /= sum;

            result_h[0] = __float22half2_rn(low_data[i]);
            result_h[1] = __float22half2_rn(high_data[i]);

            val_cast[data_id] = result_f;
        }
    }
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);

template <>
void launch_attn_softmax<float>(float* vals,
                                const float* attn_mask,
                                int batch_size,
                                int heads,
                                int sequence_length,
                                cudaStream_t stream)
{
    const int threads = 128;
    int seq_length4 = sequence_length / 4;

    dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

    int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

    dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                            subblock_max_workload * threads)
                                         : threads);
    int iterations =
        (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                 : MAX_THREAD_ITERATIONS);

    if (sequence_length <= 8)
        attn_softmax<2, (threads / 2), 2>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 16)
        attn_softmax<4, (threads / 4), 4>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 32)
        attn_softmax<8, (threads / 8), 8>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 64)
        attn_softmax<16, (threads / 16), 16>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 128)
        attn_softmax<32, (threads / 32), 32>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 256)
        attn_softmax<32, (threads / 64), 64>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else {
        const int threads = 256;
        dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

        int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

        dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                subblock_max_workload * threads)
                                             : threads);
        iterations =
            (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                     : MAX_THREAD_ITERATIONS);
        if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
            attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else
            throw std::runtime_error(
                "Unsupported sequence length! Check the restriction on MAX_THREADS and "
                "MAX_THREAD_ITERATIONS!");
    }
}
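
// A minimal host-side sketch (hypothetical names; `scores` is assumed to hold
// [batch_size, heads, sequence_length, sequence_length] attention logits and
// `mask` a [batch_size, sequence_length] additive mask, both device pointers):
//
//   launch_attn_softmax<float>(scores, mask, batch_size, heads, sequence_length, stream);
//
// The softmax is computed in place over the last sequence_length dimension.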
template <>
void launch_attn_softmax<__half>(__half* vals,
                                 const __half* attn_mask,
                                 int batch_size,
                                 int heads,
                                 int sequence_length,
                                 cudaStream_t stream)
{
    const int threads = 128;
    int seq_length4 = sequence_length / 4;

    dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

    int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

    dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                            subblock_max_workload * threads)
                                         : threads);
    int iterations =
        (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                 : MAX_THREAD_ITERATIONS);

    if (sequence_length <= 8)
        attn_softmax<2, (threads / 2), 2>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 16)
        attn_softmax<4, (threads / 4), 4>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 32)
        attn_softmax<8, (threads / 8), 8>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 64)
        attn_softmax<16, (threads / 16), 16>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 128)
        attn_softmax<32, (threads / 32), 32>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 256)
        attn_softmax<32, (threads / 64), 64>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else {
        const int threads = 256;
        dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

        int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

        dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                subblock_max_workload * threads)
                                             : threads);
        iterations =
            (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                     : MAX_THREAD_ITERATIONS);
        if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
            attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else
            throw std::runtime_error(
                "Unsupported sequence length! Check the restriction on MAX_THREADS and "
                "MAX_THREAD_ITERATIONS!");
    }
}
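
// Softmax backward: for y = softmax(x) and upstream gradient dy, the
// Jacobian-vector product is
//     dx_i = y_i * (dy_i - sum_j y_j * dy_j).
// The kernels below first reduce s = sum_j y_j * dy_j per row and then
// overwrite the incoming gradient buffer with dx in place.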
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> WARP_SIZE_BITS;  // warp-count = num_threads / WARP_SIZE (32)

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
                          ? (seq_length + iteration_stride - 1) / iteration_stride
                          : MAX_THREAD_ITERATIONS);

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;

    int wid = id >> WARP_SIZE_BITS;
    int lane = id & 0x1f;

    T val_reg[MAX_THREAD_ITERATIONS];
    T soft_reg[MAX_THREAD_ITERATIONS];
    float grad_reg = 0.0f;

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + id;
        if (data_id < block_width) {
            val_reg[i] = out_grad[row * block_width + data_id];
            soft_reg[i] = soft_inp[row * block_width + data_id];

            grad_reg += ((float)val_reg[i] *
                         (float)soft_reg[i]);  // accumulate in float: doing this multiplication
                                               // in half can cost roughly 2% of accuracy!
        }
    }
    for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = grad_reg;
        b.sync();

        if (lane < warp_num) grad_reg = partialSum[lane];

        int iters = warp_num;
        if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);

        for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);

        grad_reg = g.shfl(grad_reg, id / tbSize);
    }

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + id;
        if (data_id < block_width) {
            float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
            out_grad[row * block_width + data_id] = (T)temp;
        }
    }
}
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
                                           const T* output,
                                           int softmax_length)
{
    int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
    int offset = batch_idx * softmax_length + threadIdx.x;

    grad += offset;
    output += offset;

    T grad_reg[ITERATIONS];
    T output_reg[ITERATIONS];
    float sum = 0.0;

#pragma unroll
    for (int i = 0; i < ITERATIONS; ++i) {
        int curr_idx = threadIdx.x + i * WARP_SIZE;
        if (curr_idx < softmax_length) {
            grad_reg[i] = grad[i * WARP_SIZE];
            output_reg[i] = output[i * WARP_SIZE];
            sum += (float)grad_reg[i] * (float)output_reg[i];
        }
    }

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
    for (int i = 0; i < ITERATIONS; ++i) {
        int curr_idx = threadIdx.x + i * WARP_SIZE;
        if (curr_idx < softmax_length)
            grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
    }
}
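
// softmax_backward_kernel_v2 assigns one warp per softmax row: blockDim is
// (WARP_SIZE, warps_per_block), threadIdx.y selects the row within the block,
// and ITERATIONS * WARP_SIZE must cover softmax_length (e.g. the launcher maps
// seq_length <= 384 to ITERATIONS = 12, since 12 * 32 = 384).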
__global__ void softmax_backward_kernel_arbitrary_length(__half* grad /* input & output*/,
                                                         const __half* output,
                                                         int softmax_length)
{
    int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
    int offset = batch_idx * softmax_length + threadIdx.x;

    const float4* output_cast = reinterpret_cast<const float4*>(output);
    float4* grad_cast = reinterpret_cast<float4*>(grad);

    grad_cast += offset;
    output_cast += offset;

    float sum = 0.0;
    int curr_idx = threadIdx.x;
    while (curr_idx < softmax_length) {
        float4 out_reg = output_cast[curr_idx];
        float4 grad_reg = grad_cast[curr_idx];
        __half2* out_h = reinterpret_cast<__half2*>(&out_reg);
        __half2* grad_h = reinterpret_cast<__half2*>(&grad_reg);
#pragma unroll
        for (int m = 0; m < 4; m++) grad_h[m] *= out_h[m];
        sum += ((float)grad_h[0].x + (float)grad_h[0].y + (float)grad_h[1].x + (float)grad_h[1].y) +
               ((float)grad_h[2].x + (float)grad_h[2].y + (float)grad_h[3].x + (float)grad_h[3].y);
        curr_idx += WARP_SIZE;
    }

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

    curr_idx = threadIdx.x;
    while (curr_idx < softmax_length) {
        float4 out_reg = output_cast[curr_idx];
        float4 grad_reg = grad_cast[curr_idx];
        __half* grad_h = reinterpret_cast<__half*>(&grad_reg);
        __half* out_h = reinterpret_cast<__half*>(&out_reg);

#pragma unroll
        for (int m = 0; m < 8; m++) grad_h[m] = (float)out_h[m] * ((float)grad_h[m] - sum);

        grad_cast[curr_idx] = grad_reg;
        curr_idx += WARP_SIZE;
    }
}
__global__ void softmax_backward_kernel_arbitrary_length(float* grad /* input & output*/,
                                                         const float* output,
                                                         int softmax_length)
{
    int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
    int offset = batch_idx * softmax_length + threadIdx.x;

    const float4* output_cast = reinterpret_cast<const float4*>(output);
    float4* grad_cast = reinterpret_cast<float4*>(grad);

    grad_cast += offset;
    output_cast += offset;

    float sum = 0.0;
    int curr_idx = threadIdx.x;
    while (curr_idx < softmax_length) {
        float4 out_reg = output_cast[curr_idx];
        float4 grad_reg = grad_cast[curr_idx];

        grad_reg.x *= out_reg.x;
        grad_reg.y *= out_reg.y;
        grad_reg.z *= out_reg.z;
        grad_reg.w *= out_reg.w;
        sum += (grad_reg.x + grad_reg.y + grad_reg.z + grad_reg.w);

        curr_idx += WARP_SIZE;
    }

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

    curr_idx = threadIdx.x;
    while (curr_idx < softmax_length) {
        float4 out_reg = output_cast[curr_idx];
        float4 grad_reg = grad_cast[curr_idx];

        grad_reg.x = out_reg.x * (grad_reg.x - sum);
        grad_reg.y = out_reg.y * (grad_reg.y - sum);
        grad_reg.z = out_reg.z * (grad_reg.z - sum);
        grad_reg.w = out_reg.w * (grad_reg.w - sum);

        grad_cast[curr_idx] = grad_reg;
        curr_idx += WARP_SIZE;
    }
}
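
// Both arbitrary-length kernels read and write float4 vectors, so they expect
// softmax_length expressed in float4 units: 8 packed __half elements or 4
// float elements per vector. The launcher below performs that division before
// dispatching to them.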
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream)
{
    const int warps_per_block = 4;
    dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
    dim3 block_dim(WARP_SIZE, warps_per_block);

    if (seq_length <= 32)
        softmax_backward_kernel_v2<T, 1>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 64)
        softmax_backward_kernel_v2<T, 2>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 128)
        softmax_backward_kernel_v2<T, 4>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 256)
        softmax_backward_kernel_v2<T, 8>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 384)
        softmax_backward_kernel_v2<T, 12>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 512)
        softmax_backward_kernel_v2<T, 16>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 768)
        softmax_backward_kernel_v2<T, 24>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 1024)
        softmax_backward_kernel_v2<T, 32>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 2048)
        softmax_backward_kernel_v2<T, 64>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 4096)
        softmax_backward_kernel_v2<T, 128>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 8192)
        softmax_backward_kernel_v2<T, 256>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else
        // Convert seq_length to float4 units: 4 << ((sizeof(T) & 2) >> 1) is 8 for __half, 4 for float.
        softmax_backward_kernel_arbitrary_length<<<grid_dim, block_dim, 0, stream>>>(
            out_grad, soft_inp, seq_length / (4 << ((sizeof(T) & 2) >> 1)));
}
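
// A minimal host-side sketch (hypothetical names; `out_grad` carries the
// upstream gradient and is overwritten with the input gradient, `soft_inp` is
// the softmax output saved from the forward pass, both device pointers shaped
// [batch_size, heads, seq_length, seq_length]):
//
//   launch_attn_softmax_backward_v2<float>(out_grad, soft_inp, batch_size, heads, seq_length, stream);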
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
                                                      const __half* soft_inp,
                                                      int batch_size,
                                                      int heads,
                                                      int seq_length,
                                                      cudaStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
                                                     const float* soft_inp,
                                                     int batch_size,
                                                     int heads,
                                                     int seq_length,
                                                     cudaStream_t stream);