reduction_utils.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #pragma once
  5. #include "conversion_utils.h"
  6. #include "ds_kernel_utils.h"
  7. #include "memory_access_utils.h"
  8. namespace cg = cooperative_groups;
  9. namespace reduce {
  10. enum class ROpType {
  11. // Addition
  12. Add,
  13. // Maximum reduction
  14. Max,
  15. // Minimum reduction
  16. Min,
  17. };
  18. constexpr int max_threads = 1024;
  19. constexpr int max_warps = max_threads / hw_warp_size;
  20. /*
  21. High level API. The API takes in a set of operations and variables
  22. and performs that reduction operation on that variable. The reductions
  23. of each of the arguments are completely independent of each other (
  24. i.e., the val1-op1 combination has no impact on val2-op2).
  25. Example usage:
  26. ``` cpp
  27. float max_val;
  28. float min_val;
  29. reduce::block<rop::Max, rop::Min>(tb, warp, max_val, min_val);
  30. ```
  31. TODO(cmikeh2): In theory, we might be able to do this sequentially with
  32. device functions and rely on the assembler correctly behaving. My initial
  33. instinct is this won't work, but if it does it would reduce implementation
  34. cost significantly.
  35. TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic
  36. currently supports this (more incidentally than anything else). It is not
  37. uncommon in something like softmax or a fused attention kernel to map multiple
  38. reductions to a thread block, but each reduction itself is only scoped
  39. to part of the threads (i.e block size = 512, 128 threads per reduction).
  40. */
  41. template <ROpType Op, int warp_bound = max_warps>
  42. DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val);
  43. template <ROpType Op1, ROpType Op2, int warp_bound = max_warps>
  44. DS_D_INLINE void block(cg::thread_block& tb,
  45. cg::thread_block_tile<hw_warp_size>& warp,
  46. float& val1,
  47. float& val2);
  48. template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound = max_warps>
  49. DS_D_INLINE void block(cg::thread_block& tb,
  50. cg::thread_block_tile<hw_warp_size>& warp,
  51. float& val1,
  52. float& val2,
  53. float& val3);
  54. template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound = max_warps>
  55. DS_D_INLINE void block(cg::thread_block& tb,
  56. cg::thread_block_tile<hw_warp_size>& warp,
  57. float& val1,
  58. float& val2,
  59. float& val3,
  60. float& val4);
  61. /*
  62. The partitioned block is a special case of the above where in the warps of a threadblock are
  63. partitioned into separate independent reductions. For example, I might have an 8 warp thread block
  64. in which each pair of warps is processing an independent piece of data. I would then reduce that
  65. data with the something like the following:
  66. ``` cpp
  67. float max_val;
  68. reduce::partitioned_block<rop::Max, 2>(tb, warp, max_val);
  69. ```
  70. After which, each pair of warps would have coherent data with each other. Note, this API will not
  71. provide correct results if the number of warps per partition is not a power of 2.
  72. */
  73. template <ROpType Op, int num_threads>
  74. DS_D_INLINE void partitioned_block(cg::thread_block& tb,
  75. cg::thread_block_tile<hw_warp_size>& warp,
  76. float& val);
  77. template <ROpType Op1, ROpType Op2, int num_threads>
  78. DS_D_INLINE void partitioned_block(cg::thread_block& tb,
  79. cg::thread_block_tile<hw_warp_size>& warp,
  80. float& val1,
  81. float& val2);
  82. template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
  83. DS_D_INLINE void partitioned_block(cg::thread_block& tb,
  84. cg::thread_block_tile<hw_warp_size>& warp,
  85. float& val1,
  86. float& val2,
  87. float& val3);
  88. template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
  89. DS_D_INLINE void partitioned_block(cg::thread_block& tb,
  90. cg::thread_block_tile<hw_warp_size>& warp,
  91. float& val1,
  92. float& val2,
  93. float& val3,
  94. float& val4);
  95. /*
  96. Single element reduction primitives. Used inside serial collection
  97. loops.
  98. Example usage:
  99. using rop = reduce::OpType;
  100. float min = init<rop::Min>();
  101. for (int i = 0; i < 4; i++) {
  102. min = reduce::element<rop::Min>(min, data[i]);
  103. }
  104. */
  105. template <ROpType Op, typename T>
  106. DS_D_INLINE T element(const T lhs, const T rhs);
  107. template <ROpType OType, typename T = float>
  108. DS_D_INLINE T init();
  109. /********************** Internal reduction APIs **********************/
  110. /*
  111. Single element "reductions". TODO(cmikeh2): this sort of "op" concept
  112. should be refactored into its own implementation at some point. This interface
  113. may be easily expanded for new types/operations, but the typical reductions
  114. we need are covered with min/max/add on float.
  115. NOTE: there is no mean reduction because that relies on knowledge of how
  116. many values were already reduced into each scalar. Implementing this on top
  117. of reduce should be straightforward (can just wrap the sum reduction) and
  118. would be a good extension of the header.
  119. */
  120. DS_D_INLINE int _warp_rank()
  121. {
  122. const int thread_rank =
  123. threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;
  124. return thread_rank / hw_warp_size;
  125. }
  126. /* Float element reduce implementations */
  127. template <>
  128. DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
  129. {
  130. return lhs + rhs;
  131. }
  132. template <>
  133. DS_D_INLINE float element<ROpType::Max>(const float lhs, const float rhs)
  134. {
  135. return fmaxf(lhs, rhs);
  136. }
  137. template <>
  138. DS_D_INLINE float element<ROpType::Min>(const float lhs, const float rhs)
  139. {
  140. return fminf(lhs, rhs);
  141. }
  142. /* __half element reduce implementation */
  143. template <>
  144. DS_D_INLINE __half element<ROpType::Add>(const __half lhs, const __half rhs)
  145. {
  146. return lhs + rhs;
  147. }
  148. template <>
  149. DS_D_INLINE __half element<ROpType::Max>(const __half lhs, const __half rhs)
  150. {
  151. #if __CUDA_ARCH__ >= 800
  152. // Intrinsic limited to Ampere + newer
  153. return __hmax(lhs, rhs);
  154. #else
  155. return (lhs > rhs) ? lhs : rhs;
  156. #endif
  157. }
  158. template <>
  159. DS_D_INLINE __half element<ROpType::Min>(const __half lhs, const __half rhs)
  160. {
  161. #if __CUDA_ARCH__ >= 800
  162. // Intrinsic limited to Ampere + newer
  163. return __hmin(lhs, rhs);
  164. #else
  165. return (lhs < rhs) ? lhs : rhs;
  166. #endif
  167. }
  168. /* __half2 element reduce implementation */
  169. template <>
  170. DS_D_INLINE __half2 element<ROpType::Add>(const __half2 lhs, const __half2 rhs)
  171. {
  172. return lhs + rhs;
  173. }
  174. template <>
  175. DS_D_INLINE __half2 element<ROpType::Max>(const __half2 lhs, const __half2 rhs)
  176. {
  177. #if __CUDA_ARCH__ >= 800
  178. return __hmax2(lhs, rhs);
  179. #else
  180. __half2 ret_val;
  181. ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x;
  182. ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y;
  183. return ret_val;
  184. #endif
  185. }
  186. template <>
  187. DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
  188. {
  189. #if __CUDA_ARCH__ >= 800
  190. return __hmin2(lhs, rhs);
  191. #else
  192. __half2 ret_val;
  193. ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x;
  194. ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y;
  195. return ret_val;
  196. #endif
  197. }
  198. /*
  199. Reduction initialization primitives
  200. */
  201. template <>
  202. DS_D_INLINE float init<ROpType::Add>()
  203. {
  204. return 0.0f;
  205. }
  206. template <>
  207. DS_D_INLINE float init<ROpType::Min>()
  208. {
  209. // Positive infinity
  210. return INFINITY;
  211. }
  212. template <>
  213. DS_D_INLINE float init<ROpType::Max>()
  214. {
  215. // Negative infinity
  216. return -INFINITY;
  217. }
  218. template <>
  219. DS_D_INLINE __half init<ROpType::Add>()
  220. {
  221. constexpr __half_raw zero = {0x0000};
  222. return __half(zero);
  223. }
  224. template <>
  225. DS_D_INLINE __half init<ROpType::Min>()
  226. {
  227. constexpr __half_raw inf = {0x7C00};
  228. return __half(inf);
  229. }
  230. template <>
  231. DS_D_INLINE __half init<ROpType::Max>()
  232. {
  233. constexpr __half_raw neg_inf = {0xFC00};
  234. return __half(neg_inf);
  235. }
  236. template <>
  237. DS_D_INLINE __half2 init<ROpType::Add>()
  238. {
  239. #ifdef __HIP_PLATFORM_HCC__
  240. return __half2{_Float16_2{0x0000, 0x0000}};
  241. #else
  242. constexpr __half2_raw zero = {0x0000, 0x0000};
  243. return __half2(zero);
  244. #endif
  245. }
  246. template <>
  247. DS_D_INLINE __half2 init<ROpType::Min>()
  248. {
  249. #ifdef __HIP_PLATFORM_HCC__
  250. return __half2{_Float16_2{0x7C00, 0x7C00}};
  251. #else
  252. constexpr __half2_raw inf = {0x7C00, 0x7C00};
  253. return __half2(inf);
  254. #endif
  255. }
  256. template <>
  257. DS_D_INLINE __half2 init<ROpType::Max>()
  258. {
  259. #ifdef __HIP_PLATFORM_HCC__
  260. return __half2{_Float16_2{0xFC00, 0xFC00}};
  261. #else
  262. constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
  263. return __half2(neg_inf);
  264. #endif
  265. }
  266. template <ROpType Op, typename T>
  267. DS_D_INLINE void init(T* data)
  268. {
  269. data[0] = init<Op, T>();
  270. }
  271. template <ROpType Op1, ROpType Op2, typename T>
  272. DS_D_INLINE void init(T* data)
  273. {
  274. data[0] = init<Op1, T>();
  275. data[1] = init<Op2, T>();
  276. }
  277. template <ROpType Op1, ROpType Op2, ROpType Op3, typename T>
  278. DS_D_INLINE void init(T* data)
  279. {
  280. data[0] = init<Op1, T>();
  281. data[1] = init<Op2, T>();
  282. data[2] = init<Op3, T>();
  283. }
  284. template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, typename T>
  285. DS_D_INLINE void init(T* data)
  286. {
  287. data[0] = init<Op1, T>();
  288. data[1] = init<Op2, T>();
  289. data[2] = init<Op3, T>();
  290. data[3] = init<Op4, T>();
  291. }
  292. /*
  293. Warp reduction primitives
  294. `reduction_width` is an unsafe template parameter, that is that
  295. when using `reduction_width` < hw_warp_size the warp is partitioned
  296. into `hw_warp_size` / `reduction_width` groups of partial sums.
  297. If someone can figure out how to use variadic templates in a reasonable way
  298. here (fold is C++17 only and I don't think helps and recursion feels like
  299. huge overkill that harms readability) that would be wonderful.
  300. */
  301. template <ROpType Op, int reduce_width = hw_warp_size>
  302. DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
  303. {
  304. #pragma unroll
  305. for (int i = 1; i < reduce_width; i *= 2) {
  306. data[0] = element<Op>(data[0], warp.shfl_xor(data[0], i));
  307. }
  308. }
  309. template <ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
  310. DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
  311. {
  312. #pragma unroll
  313. for (int i = 1; i < reduce_width; i *= 2) {
  314. data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
  315. data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
  316. }
  317. }
  318. template <ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
  319. DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
  320. {
  321. #pragma unroll
  322. for (int i = 1; i < reduce_width; i *= 2) {
  323. data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
  324. data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
  325. data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
  326. }
  327. }
  328. template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int reduce_width = hw_warp_size>
  329. DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
  330. {
  331. #pragma unroll
  332. for (int i = 1; i < reduce_width; i *= 2) {
  333. data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
  334. data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
  335. data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
  336. data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], i));
  337. }
  338. }
  339. /*
  340. Implementation for primary block reduction that serves both `block` and
  341. `partitioned_block`.
  342. Total warps refers to the reduction width of the reduction, not
  343. the number of warps in the block (which may exceed that
  344. if the block is partitioned or if we do a conservative bound at
  345. compile time).
  346. */
  347. template <int total_warps, ROpType... Ops>
  348. DS_D_INLINE void _block(cg::thread_block& tb,
  349. cg::thread_block_tile<hw_warp_size>& warp_arg,
  350. float* data)
  351. {
  352. constexpr int elems = sizeof...(Ops);
  353. // Separated for now in case this no longer is true
  354. constexpr int bytes = sizeof(float);
  355. // Unused when `partition_size == 1` or total_warps == 1
  356. __shared__ float reduce_buffer[max_warps * elems];
  357. #ifdef __HIP_PLATFORM_HCC__
  358. const int total_threads = blockDim.x * blockDim.y * blockDim.z;
  359. const int running_warps = total_threads / hw_warp_size;
  360. #else
  361. const int running_warps = warp_arg.meta_group_size();
  362. #endif
  363. // Always perform warp-scope reduction
  364. _warp<Ops...>(warp_arg, data);
  365. // If max_warps == 1 let's skip the runtime check
  366. if (total_warps != 1) {
  367. if (warp_arg.thread_rank() == 0) {
  368. #pragma unroll
  369. for (int i = 0; i < elems; i++) {
  370. mem_access::store_shared<bytes>(reduce_buffer + elems * _warp_rank() + i, data + i);
  371. }
  372. }
  373. // Synchronization inside block-uniform conditional is safe
  374. tb.sync();
  375. if (_warp_rank() == 0) {
  376. if (warp_arg.thread_rank() < running_warps) {
  377. #pragma unroll
  378. for (int i = 0; i < elems; i++) {
  379. mem_access::load_shared<bytes>(
  380. data + i, reduce_buffer + elems * warp_arg.thread_rank() + i);
  381. }
  382. } else {
  383. init<Ops...>(data);
  384. }
  385. _warp<Ops..., total_warps>(warp_arg, data);
  386. #pragma unroll
  387. for (int i = 0; i < elems; i++) {
  388. mem_access::store_shared<bytes>(reduce_buffer + elems * warp_arg.thread_rank() + i,
  389. data + i);
  390. }
  391. }
  392. // Synchronization inside block-uniform conditional is safe
  393. tb.sync();
  394. #pragma unroll
  395. for (int i = 0; i < elems; i++) {
  396. mem_access::load_shared<bytes>(data + i, reduce_buffer + _warp_rank() * elems + i);
  397. }
  398. }
  399. }
  400. /*
  401. Main API implementations. For the most part, they just convert the individual
  402. variables into arrays, which makes working with them easier with a single
  403. implementation. In theory, we could use the `_block` implementation as another
  404. option, but the nature of using a pointer is a little less safe and this allows
  405. us to obfuscate the details of the partitioned implementation.
  406. */
  407. template <ROpType Op, int warp_bound>
  408. DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
  409. {
  410. _block<warp_bound, Op>(tb, warp, &val);
  411. }
  412. template <ROpType Op1, ROpType Op2, int warp_bound>
  413. DS_D_INLINE void block(cg::thread_block& tb,
  414. cg::thread_block_tile<hw_warp_size>& warp,
  415. float& val1,
  416. float& val2)
  417. {
  418. float data[2] = {val1, val2};
  419. _block<warp_bound, Op1, Op2>(tb, warp, data);
  420. val1 = data[0];
  421. val2 = data[1];
  422. }
  423. template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound>
  424. DS_D_INLINE void block(cg::thread_block& tb,
  425. cg::thread_block_tile<hw_warp_size>& warp,
  426. float& val1,
  427. float& val2,
  428. float& val3)
  429. {
  430. float data[3] = {val1, val2, val3};
  431. _block<warp_bound, Op1, Op2, Op3>(tb, warp, data);
  432. val1 = data[0];
  433. val2 = data[1];
  434. val3 = data[2];
  435. }
  436. template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound>
  437. DS_D_INLINE void block(cg::thread_block& tb,
  438. cg::thread_block_tile<hw_warp_size>& warp,
  439. float& val1,
  440. float& val2,
  441. float& val3,
  442. float& val4)
  443. {
  444. float data[4] = {val1, val2, val3, val4};
  445. _block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
  446. val1 = data[0];
  447. val2 = data[1];
  448. val3 = data[2];
  449. val4 = data[3];
  450. }
  451. /*
  452. Note: for the partitioned blocks, the implementation does not support non-power of 2 blocks in order
  453. to shorten block scale reduction length.
  454. */
  455. template <ROpType Op, int num_threads>
  456. DS_D_INLINE void partitioned_block(cg::thread_block& tb,
  457. cg::thread_block_tile<hw_warp_size>& warp,
  458. float& val)
  459. {
  460. if (num_threads <= hw_warp_size) {
  461. _warp<Op, num_threads>(warp, &val);
  462. } else {
  463. constexpr int num_warps = num_threads / hw_warp_size;
  464. _block<num_warps, Op>(tb, warp, &val);
  465. }
  466. }
  467. template <ROpType Op1, ROpType Op2, int num_threads>
  468. DS_D_INLINE void partitioned_block(cg::thread_block& tb,
  469. cg::thread_block_tile<hw_warp_size>& warp,
  470. float& val1,
  471. float& val2)
  472. {
  473. float data[2] = {val1, val2};
  474. if (num_threads <= hw_warp_size) {
  475. _warp<Op1, Op2, num_threads>(warp, data);
  476. } else {
  477. constexpr int num_warps = num_threads / hw_warp_size;
  478. _block<num_warps, Op1, Op2>(tb, warp, data);
  479. }
  480. val1 = data[0];
  481. val2 = data[1];
  482. }
  483. template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
  484. DS_D_INLINE void partitioned_block(cg::thread_block& tb,
  485. cg::thread_block_tile<hw_warp_size>& warp,
  486. float& val1,
  487. float& val2,
  488. float& val3)
  489. {
  490. float data[3] = {val1, val2, val3};
  491. if (num_threads <= hw_warp_size) {
  492. _warp<Op1, Op2, Op3, num_threads>(warp, data);
  493. } else {
  494. constexpr int num_warps = num_threads / hw_warp_size;
  495. _block<num_warps, Op1, Op2, Op3>(tb, warp, data);
  496. }
  497. val1 = data[0];
  498. val2 = data[1];
  499. val3 = data[2];
  500. }
  501. template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
  502. DS_D_INLINE void partitioned_block(cg::thread_block& tb,
  503. cg::thread_block_tile<hw_warp_size>& warp,
  504. float& val1,
  505. float& val2,
  506. float& val3,
  507. float& val4)
  508. {
  509. float data[4] = {val1, val2, val3, val4};
  510. if (num_threads <= hw_warp_size) {
  511. _warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
  512. } else {
  513. constexpr int num_warps = num_threads / hw_warp_size;
  514. _block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
  515. }
  516. val1 = data[0];
  517. val2 = data[1];
  518. val3 = data[2];
  519. val4 = data[3];
  520. }
  521. } // namespace reduce