reduction_utils.h

/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#pragma once

#include <cooperative_groups.h>
#include <cuda_fp16.h>

#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;
namespace reduce {

enum class ROpType {
    // Addition
    Add,

    // Maximum reduction
    Max,

    // Minimum reduction
    Min,
};

constexpr int max_threads = 1024;
constexpr int max_warps = max_threads / hw_warp_size;
/*
High level API. The API takes in a set of operations and variables
and performs that reduction operation on that variable. The reductions
of each of the arguments are completely independent of each other
(i.e., the val1-op1 combination has no impact on val2-op2).

Example usage:
``` cpp
using rop = reduce::ROpType;
float max_val;
float min_val;
reduce::block<rop::Max, rop::Min>(tb, warp, max_val, min_val);
```

TODO(cmikeh2): In theory, we might be able to do this sequentially with
device functions and rely on the assembler behaving correctly. My initial
instinct is this won't work, but if it does it would reduce implementation
cost significantly.

TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic
currently supports this (more incidentally than anything else). It is not
uncommon in something like softmax or a fused attention kernel to map multiple
reductions to a thread block, but each reduction itself is only scoped
to part of the threads (i.e., block size = 512, 128 threads per reduction).
*/
template <ROpType Op, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val);

template <ROpType Op1, ROpType Op2, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2);

template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3);

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3,
                       float& val4);
/*
The partitioned block is a special case of the above wherein the warps of a threadblock are
partitioned into separate independent reductions. For example, I might have an 8 warp thread block
in which each pair of warps is processing an independent piece of data. I would then reduce that
data with something like the following (num_threads = 64 for a 2 warp partition):
``` cpp
using rop = reduce::ROpType;
float max_val;
reduce::partitioned_block<rop::Max, 64>(tb, warp, max_val);
```
After which, each pair of warps would have coherent data with each other. Note, this API will not
provide correct results if the number of warps per partition is not a power of 2.
*/
template <ROpType Op, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val);

template <ROpType Op1, ROpType Op2, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2);

template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3);

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3,
                                   float& val4);
/*
Single element reduction primitives. Used inside serial collection
loops.

Example usage:
using rop = reduce::ROpType;
float min = reduce::init<rop::Min>();
for (int i = 0; i < 4; i++) {
    min = reduce::element<rop::Min>(min, data[i]);
}
*/
template <ROpType Op, typename T>
DS_D_INLINE T element(const T lhs, const T rhs);

template <ROpType OType, typename T = float>
DS_D_INLINE T init();
/********************** Internal reduction APIs **********************/

/*
Single element "reductions". TODO(cmikeh2): this sort of "op" concept
should be refactored into its own implementation at some point. This interface
may be easily expanded for new types/operations, but the typical reductions
we need are covered with min/max/add on float.

NOTE: there is no mean reduction because that relies on knowledge of how
many values were already reduced into each scalar. Implementing this on top
of reduce should be straightforward (can just wrap the sum reduction) and
would be a good extension of the header.
*/
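/*
A minimal sketch of the mean extension described in the note above. This is
illustrative, not part of the original header: `block_mean` is a hypothetical
name, and the caller must supply the element count since the reduction itself
does not track how many values were folded into each scalar.
*/
template <int warp_bound = max_warps>
DS_D_INLINE float block_mean(cg::thread_block& tb,
                             cg::thread_block_tile<hw_warp_size>& warp,
                             float val,
                             int count)
{
    // Sum across the block, then divide by the caller-supplied count.
    block<ROpType::Add, warp_bound>(tb, warp, val);
    return val / count;
}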
/* Float element reduce implementations */
template <>
DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE float element<ROpType::Max>(const float lhs, const float rhs)
{
    return fmaxf(lhs, rhs);
}

template <>
DS_D_INLINE float element<ROpType::Min>(const float lhs, const float rhs)
{
    return fminf(lhs, rhs);
}

/* __half element reduce implementation */
template <>
DS_D_INLINE __half element<ROpType::Add>(const __half lhs, const __half rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE __half element<ROpType::Max>(const __half lhs, const __half rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmax(lhs, rhs);
#else
    return (lhs > rhs) ? lhs : rhs;
#endif
}

template <>
DS_D_INLINE __half element<ROpType::Min>(const __half lhs, const __half rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmin(lhs, rhs);
#else
    return (lhs < rhs) ? lhs : rhs;
#endif
}

/* __half2 element reduce implementation */
template <>
DS_D_INLINE __half2 element<ROpType::Add>(const __half2 lhs, const __half2 rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE __half2 element<ROpType::Max>(const __half2 lhs, const __half2 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmax2(lhs, rhs);
#else
    __half2 ret_val;
    ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}

template <>
DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmin2(lhs, rhs);
#else
    __half2 ret_val;
    ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}
/*
Reduction initialization primitives
*/
template <>
DS_D_INLINE float init<ROpType::Add>()
{
    return 0.0f;
}

template <>
DS_D_INLINE float init<ROpType::Min>()
{
    // Positive infinity
    return INFINITY;
}

template <>
DS_D_INLINE float init<ROpType::Max>()
{
    // Negative infinity
    return -INFINITY;
}

template <>
DS_D_INLINE __half init<ROpType::Add>()
{
    constexpr __half_raw zero = {0x0000};
    return __half(zero);
}

template <>
DS_D_INLINE __half init<ROpType::Min>()
{
    constexpr __half_raw inf = {0x7C00};
    return __half(inf);
}

template <>
DS_D_INLINE __half init<ROpType::Max>()
{
    constexpr __half_raw neg_inf = {0xFC00};
    return __half(neg_inf);
}

template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{
    constexpr __half2_raw zero = {0x0000, 0x0000};
    return __half2(zero);
}

template <>
DS_D_INLINE __half2 init<ROpType::Min>()
{
    constexpr __half2_raw inf = {0x7C00, 0x7C00};
    return __half2(inf);
}

template <>
DS_D_INLINE __half2 init<ROpType::Max>()
{
    constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
    return __half2(neg_inf);
}
template <ROpType Op, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op, T>();
}

template <ROpType Op1, ROpType Op2, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
}

template <ROpType Op1, ROpType Op2, ROpType Op3, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
    data[2] = init<Op3, T>();
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
    data[2] = init<Op3, T>();
    data[3] = init<Op4, T>();
}
/*
Warp reduction primitives

`reduce_width` is an unsafe template parameter: when using
`reduce_width` < hw_warp_size, the warp is partitioned into
`hw_warp_size` / `reduce_width` independent groups of partial sums.

If someone can figure out how to use variadic templates in a reasonable way
here (fold is C++17 only and I don't think it helps, and recursion feels like
huge overkill that harms readability) that would be wonderful.
*/
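/*
Illustration (not from the original header) of the XOR-shuffle butterfly
below, assuming hw_warp_size = 32 and reduce_width = 4: on step i = 1 each
lane combines with lane ^ 1, on step i = 2 with lane ^ 2. After log2(4) = 2
steps, lanes {0,1,2,3}, {4,5,6,7}, ... each hold the reduction of their own
group of 4 lanes, i.e. 8 independent partial results per warp.
*/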
template <ROpType Op, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op>(data[0], warp.shfl_xor(data[0], i));
    }
}

template <ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
    }
}

template <ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
    }
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
        data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], i));
    }
}
/*
Implementation for primary block reduction that serves both `block` and
`partitioned_block`.

`local_warp_rank` refers to the warp's location within the partition, so
for an unpartitioned threadblock this will be equivalent to
`warp_arg.meta_group_rank()`.

Similarly, the warp offset is the `local_warp_rank` of the warp with the
lowest rank in the partition. In the case of an 8 warp block with a
4 warp reduction, this would map to [0, 0, 0, 0, 4, 4, 4, 4].

Partition size is the number of warps per partition (equal to the warp count
of the thread block in the default case). This enables us to perform only the
warp-scope reduction when that is sufficient.
*/
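/*
Worked example (illustrative, not from the original header) of how the
secondary reduction below handles partitions without consulting `warp_offset`:
take an 8 warp block with total_warps = 4. Lane w of warp 0 loads warp w's
partial from shared memory, and the width-4 XOR shuffle reduces lanes {0..3}
and {4..7} independently. Lane w therefore ends up holding the result for the
partition containing warp w, so after the store and sync each warp simply
reads back the slot indexed by its own meta_group_rank().
*/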
template <int total_warps, ROpType... Ops>
DS_D_INLINE void _block(cg::thread_block& tb,
                        cg::thread_block_tile<hw_warp_size>& warp_arg,
                        float* data,
                        int warp_offset)
{
    constexpr int elems = sizeof...(Ops);
    // Separated for now in case this no longer is true
    constexpr int bytes = sizeof(float);
    // Unused when `partition_size == 1` or total_warps == 1
    __shared__ float reduce_buffer[max_warps * elems];

    // Always perform warp-scope reduction
    _warp<Ops...>(warp_arg, data);

    // If max_warps == 1 let's skip the runtime check
    if (warp_arg.meta_group_size() > 1 && total_warps != 1) {
        // Lane 0 of each warp publishes that warp's partial results
        if (warp_arg.thread_rank() == 0) {
#pragma unroll
            for (int i = 0; i < elems; i++) {
                mem_access::store_shared<bytes>(
                    reduce_buffer + elems * warp_arg.meta_group_rank() + i, data + i);
            }
        }

        // Synchronization inside block-uniform conditional is safe
        tb.sync();

        if (warp_arg.meta_group_rank() == 0) {
            // Lane w of warp 0 picks up warp w's partials; surplus lanes hold identities
            if (warp_arg.thread_rank() < warp_arg.meta_group_size()) {
#pragma unroll
                for (int i = 0; i < elems; i++) {
                    mem_access::load_shared<bytes>(
                        data + i, reduce_buffer + elems * warp_arg.thread_rank() + i);
                }
            } else {
                init<Ops...>(data);
            }

            // Reduce across warps (width = total_warps), then publish the results
            _warp<Ops..., total_warps>(warp_arg, data);

#pragma unroll
            for (int i = 0; i < elems; i++) {
                mem_access::store_shared<bytes>(reduce_buffer + elems * warp_arg.thread_rank() + i,
                                                data + i);
            }
        }

        // Synchronization inside block-uniform conditional is safe
        tb.sync();

        // Every warp reads back the reduced values for its own slot
#pragma unroll
        for (int i = 0; i < elems; i++) {
            mem_access::load_shared<bytes>(data + i,
                                           reduce_buffer + warp_arg.meta_group_rank() * elems + i);
        }
    }
}
/*
Main API implementations. For the most part, they just convert the individual
variables into arrays, which makes working with them easier with a single
implementation. In theory, we could use the `_block` implementation as another
option, but the nature of using a pointer is a little less safe and this allows
us to obfuscate the details of the partitioned implementation.
*/
template <ROpType Op, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
{
    _block<warp_bound, Op>(tb, warp, &val, 0);
}

template <ROpType Op1, ROpType Op2, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2)
{
    float data[2] = {val1, val2};
    _block<warp_bound, Op1, Op2>(tb, warp, data, 0);
    val1 = data[0];
    val2 = data[1];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3)
{
    float data[3] = {val1, val2, val3};
    _block<warp_bound, Op1, Op2, Op3>(tb, warp, data, 0);
    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3,
                       float& val4)
{
    float data[4] = {val1, val2, val3, val4};
    _block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data, 0);
    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
    val4 = data[3];
}
/*
Note: the partitioned block implementations only support power-of-2 partition
sizes. This keeps the block-scale reduction step short and lets the partition
offset be computed with a simple mask.
*/
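/*
Worked example (illustrative, not from the original header) of the offset
computation below: with num_warps = 2, `meta_group_rank() & ~(num_warps - 1)`
masks off the low bit, so warp ranks [0..7] map to partition offsets
[0, 0, 2, 2, 4, 4, 6, 6]. The mask identity only holds when num_warps is a
power of 2, which is why non-power-of-2 partitions are unsupported.
*/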
template <ROpType Op, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val)
{
    if (num_threads <= hw_warp_size) {
        _warp<Op, num_threads>(warp, &val);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
        _block<num_warps, Op>(tb, warp, &val, warp_offset);
    }
}

template <ROpType Op1, ROpType Op2, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2)
{
    float data[2] = {val1, val2};
    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
        _block<num_warps, Op1, Op2>(tb, warp, data, warp_offset);
    }
    val1 = data[0];
    val2 = data[1];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3)
{
    float data[3] = {val1, val2, val3};
    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, Op3, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
        _block<num_warps, Op1, Op2, Op3>(tb, warp, data, warp_offset);
    }
    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3,
                                   float& val4)
{
    float data[4] = {val1, val2, val3, val4};
    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
        _block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data, warp_offset);
    }
    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
    val4 = data[3];
}
}  // namespace reduce
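
/*
End-to-end usage sketch (illustrative, not part of the original header): a
kernel that computes the maximum of `n` floats per block. The kernel name and
launch shape are hypothetical; `tb` and `warp` are constructed via
cooperative_groups, as the callers of this header do.

__global__ void block_max_kernel(const float* input, float* output, int n)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // Serial collection loop using the single-element primitives
    float max_val = reduce::init<reduce::ROpType::Max>();
    for (int i = tb.thread_rank(); i < n; i += tb.size()) {
        max_val = reduce::element<reduce::ROpType::Max>(max_val, input[i]);
    }

    // Block-scope reduction; afterwards every thread holds the block maximum
    reduce::block<reduce::ROpType::Max>(tb, warp, max_val);

    if (tb.thread_rank() == 0) { output[blockIdx.x] = max_val; }
}
*/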