reduction_utils.h
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

namespace reduce {

enum class ROpType {
    // Addition
    Add,

    // Maximum reduction
    Max,

    // Minimum reduction
    Min,
};

constexpr int max_threads = 1024;
constexpr int max_warps = max_threads / hw_warp_size;

/*
High level API. The API takes in a set of operations and variables
and performs the corresponding reduction on each variable. The reductions
of each of the arguments are completely independent of each other
(i.e., the val1-op1 combination has no impact on val2-op2).

Example usage:
``` cpp
float max_val;
float min_val;
reduce::block<rop::Max, rop::Min>(tb, warp, max_val, min_val);
```

TODO(cmikeh2): In theory, we might be able to do this sequentially with
device functions and rely on the assembler correctly behaving. My initial
instinct is this won't work, but if it does it would reduce implementation
cost significantly.

TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic
currently supports this (more incidentally than anything else). It is not
uncommon in something like softmax or a fused attention kernel to map multiple
reductions to a thread block, but each reduction itself is only scoped
to part of the threads (i.e., block size = 512, 128 threads per reduction).
*/

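/*
Hedged usage sketch (illustrative only, not part of this header): a minimal kernel that
feeds the block API, assuming a 1-D launch with one element per thread. The kernel name
and buffers are hypothetical.
``` cpp
__global__ void block_max_example(const float* in, float* out)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    float val = in[blockIdx.x * blockDim.x + threadIdx.x];
    reduce::block<reduce::ROpType::Max>(tb, warp, val);

    // After the call every thread in the block holds the block-wide maximum.
    if (tb.thread_rank() == 0) out[blockIdx.x] = val;
}
```
*/
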
template <ROpType Op, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val);

template <ROpType Op1, ROpType Op2, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2);

template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3);

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3,
                       float& val4);

/*
The partitioned block is a special case of the above wherein the warps of a threadblock are
partitioned into separate independent reductions. For example, I might have an 8 warp thread block
in which each pair of warps is processing an independent piece of data. I would then reduce that
data with something like the following (the second template parameter is the number of threads in
each partition):
``` cpp
float max_val;
reduce::partitioned_block<rop::Max, 2 * hw_warp_size>(tb, warp, max_val);
```
After which, each pair of warps would have coherent data with each other. Note, this API will not
provide correct results if the number of warps per partition is not a power of 2.
*/

template <ROpType Op, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val);

template <ROpType Op1, ROpType Op2, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2);

template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3);

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3,
                                   float& val4);

/*
Single element reduction primitives. Used inside serial collection
loops.

Example usage:
using rop = reduce::ROpType;
float min = init<rop::Min>();
for (int i = 0; i < 4; i++) {
    min = reduce::element<rop::Min>(min, data[i]);
}
*/

template <ROpType Op, typename T>
DS_D_INLINE T element(const T lhs, const T rhs);

template <ROpType OType, typename T = float>
DS_D_INLINE T init();

/********************** Internal reduction APIs **********************/

/*
Single element "reductions". TODO(cmikeh2): this sort of "op" concept
should be refactored into its own implementation at some point. This interface
may be easily expanded for new types/operations, but the typical reductions
we need are covered with min/max/add on float.

NOTE: there is no mean reduction because that relies on knowledge of how
many values were already reduced into each scalar. Implementing this on top
of reduce should be straightforward (can just wrap the sum reduction) and
would be a good extension of the header.
*/

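/*
Hedged sketch (illustrative, not part of this header): a mean can be layered on the Add
reduction when the caller tracks the total element count. `elems_per_thread`, `total_elems`,
and `data` are hypothetical caller-side values.
``` cpp
using rop = reduce::ROpType;
float sum = reduce::init<rop::Add>();
for (int i = 0; i < elems_per_thread; i++) { sum = reduce::element<rop::Add>(sum, data[i]); }
reduce::block<rop::Add>(tb, warp, sum);
const float mean = sum / static_cast<float>(total_elems);
```
*/
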
DS_D_INLINE int _warp_rank()
{
    const int thread_rank =
        threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;
    return thread_rank / hw_warp_size;
}

/* Float element reduce implementations */
template <>
DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE double element<ROpType::Add>(const double lhs, const double rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE float element<ROpType::Max>(const float lhs, const float rhs)
{
    return fmaxf(lhs, rhs);
}

template <>
DS_D_INLINE float element<ROpType::Min>(const float lhs, const float rhs)
{
    return fminf(lhs, rhs);
}

/* __half element reduce implementation */
template <>
DS_D_INLINE __half element<ROpType::Add>(const __half lhs, const __half rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE __half element<ROpType::Max>(const __half lhs, const __half rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmax(lhs, rhs);
#else
    return (lhs > rhs) ? lhs : rhs;
#endif
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 element<ROpType::Max>(const __nv_bfloat16 lhs, const __nv_bfloat16 rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmax(lhs, rhs);
#else
    return (lhs > rhs) ? lhs : rhs;
#endif
}
#endif

template <>
DS_D_INLINE __half element<ROpType::Min>(const __half lhs, const __half rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmin(lhs, rhs);
#else
    return (lhs < rhs) ? lhs : rhs;
#endif
}

/* __half2 element reduce implementation */
template <>
DS_D_INLINE __half2 element<ROpType::Add>(const __half2 lhs, const __half2 rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE __half2 element<ROpType::Max>(const __half2 lhs, const __half2 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmax2(lhs, rhs);
#else
    __half2 ret_val;
    ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat162 element<ROpType::Max>(const __nv_bfloat162 lhs, const __nv_bfloat162 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmax2(lhs, rhs);
#else
    __nv_bfloat162 ret_val;
    ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}
#endif

template <>
DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmin2(lhs, rhs);
#else
    __half2 ret_val;
    ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}

template <>
DS_D_INLINE int32_t element<ROpType::Add>(const int32_t lhs, const int32_t rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE int32_t element<ROpType::Max>(const int32_t lhs, const int32_t rhs)
{
    return (lhs > rhs) ? lhs : rhs;
}

template <>
DS_D_INLINE int32_t element<ROpType::Min>(const int32_t lhs, const int32_t rhs)
{
    return (lhs < rhs) ? lhs : rhs;
}

template <>
DS_D_INLINE uint32_t element<ROpType::Add>(const uint32_t lhs, const uint32_t rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE uint32_t element<ROpType::Max>(const uint32_t lhs, const uint32_t rhs)
{
    return (lhs > rhs) ? lhs : rhs;
}

template <>
DS_D_INLINE uint32_t element<ROpType::Min>(const uint32_t lhs, const uint32_t rhs)
{
    return (lhs < rhs) ? lhs : rhs;
}

template <>
DS_D_INLINE int64_t element<ROpType::Add>(const int64_t lhs, const int64_t rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE int64_t element<ROpType::Max>(const int64_t lhs, const int64_t rhs)
{
    return (lhs > rhs) ? lhs : rhs;
}

template <>
DS_D_INLINE int64_t element<ROpType::Min>(const int64_t lhs, const int64_t rhs)
{
    return (lhs < rhs) ? lhs : rhs;
}

/*
Reduction initialization primitives
*/
template <>
DS_D_INLINE float init<ROpType::Add>()
{
    return 0.0f;
}

template <>
DS_D_INLINE double init<ROpType::Add>()
{
    return (double)0.0f;
}

template <>
DS_D_INLINE float init<ROpType::Min>()
{
    // Positive infinity
    return INFINITY;
}

template <>
DS_D_INLINE float init<ROpType::Max>()
{
    // Negative infinity
    return -INFINITY;
}

template <>
DS_D_INLINE __half init<ROpType::Add>()
{
    constexpr __half_raw zero = {0x0000};
    return __half(zero);
}

template <>
DS_D_INLINE __half init<ROpType::Min>()
{
    constexpr __half_raw inf = {0x7C00};
    return __half(inf);
}

template <>
DS_D_INLINE __half init<ROpType::Max>()
{
    constexpr __half_raw neg_inf = {0xFC00};
    return __half(neg_inf);
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 init<ROpType::Max>()
{
    constexpr __nv_bfloat16_raw neg_inf = {0xFF80};
    return __nv_bfloat16(neg_inf);
}
#endif

template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{
#ifdef __HIP_PLATFORM_AMD__
    return __half2{_Float16_2{0x0000, 0x0000}};
#else
    constexpr __half2_raw zero = {0x0000, 0x0000};
    return __half2(zero);
#endif
}

template <>
DS_D_INLINE __half2 init<ROpType::Min>()
{
#ifdef __HIP_PLATFORM_AMD__
    return __half2{_Float16_2{0x7C00, 0x7C00}};
#else
    constexpr __half2_raw inf = {0x7C00, 0x7C00};
    return __half2(inf);
#endif
}

template <>
DS_D_INLINE __half2 init<ROpType::Max>()
{
#ifdef __HIP_PLATFORM_AMD__
    return __half2{_Float16_2{0xFC00, 0xFC00}};
#else
    constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
    return __half2(neg_inf);
#endif
}

template <>
DS_D_INLINE int32_t init<ROpType::Add>()
{
    return 0;
}

template <>
DS_D_INLINE int32_t init<ROpType::Min>()
{
    return 0x7FFFFFFF;
}

template <>
DS_D_INLINE int32_t init<ROpType::Max>()
{
    return 0x80000000;
}

template <>
DS_D_INLINE uint32_t init<ROpType::Add>()
{
    return 0;
}

template <>
DS_D_INLINE uint32_t init<ROpType::Min>()
{
    return 0xFFFFFFFF;
}

template <>
DS_D_INLINE uint32_t init<ROpType::Max>()
{
    return 0;
}

template <>
DS_D_INLINE int64_t init<ROpType::Add>()
{
    return 0;
}

template <>
DS_D_INLINE int64_t init<ROpType::Min>()
{
    return 0x7FFFFFFFFFFFFFFF;
}

template <>
DS_D_INLINE int64_t init<ROpType::Max>()
{
    return 0x8000000000000000;
}

template <>
DS_D_INLINE uint64_t init<ROpType::Add>()
{
    return 0;
}

template <>
DS_D_INLINE uint64_t init<ROpType::Min>()
{
    return 0xFFFFFFFFFFFFFFFF;
}

template <>
DS_D_INLINE uint64_t init<ROpType::Max>()
{
    return 0;
}

template <ROpType Op, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op, T>();
}

template <ROpType Op1, ROpType Op2, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
}

template <ROpType Op1, ROpType Op2, ROpType Op3, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
    data[2] = init<Op3, T>();
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
    data[2] = init<Op3, T>();
    data[3] = init<Op4, T>();
}

/*
Warp reduction primitives

`reduce_width` is an unsafe template parameter in that, when `reduce_width` < hw_warp_size,
the warp is partitioned into `hw_warp_size` / `reduce_width` independent groups of partial
reductions.

If someone can figure out how to use variadic templates in a reasonable way here
(fold expressions are C++17 only and I don't think they help, and recursion feels like
huge overkill that harms readability) that would be wonderful.
*/

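/*
Illustrative note (an assumption drawn from the shfl_xor pattern below, not stated in the
original): with `reduce_width` = 16 the loop only crosses lane-index bits 0-3, so lanes 0-15
and lanes 16-31 each hold an independent partial reduction, e.g.
``` cpp
// Two independent 16-lane max reductions within a single warp:
_warp<float, ROpType::Max, 16>(warp, &val);
```
*/
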
template <typename T, ROpType Op, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op>(data[0], warp.shfl_xor(data[0], i));
    }
}

template <typename T, ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
    }
}

template <typename T, ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
    }
}

template <typename T,
          ROpType Op1,
          ROpType Op2,
          ROpType Op3,
          ROpType Op4,
          int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
        data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], i));
    }
}

/*
Implementation for the primary block reduction that serves both `block` and
`partitioned_block`.

`total_warps` refers to the reduction width of the reduction, not the number of warps
in the block (which may exceed it if the block is partitioned or if we take a
conservative bound at compile time).
*/

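// Example (illustrative): a 256-thread block split into two independent 128-thread
// partitions runs 8 warps in total, but each partition reduces across total_warps = 4.
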
template <typename T, int total_warps, ROpType... Ops>
DS_D_INLINE void _block(cg::thread_block& tb,
                        cg::thread_block_tile<hw_warp_size>& warp_arg,
                        T* data)
{
    constexpr int elems = sizeof...(Ops);
    constexpr int bytes = sizeof(T);
    // Unused when `partition_size == 1` or total_warps == 1
    __shared__ T reduce_buffer[max_warps * elems];

#ifdef __HIP_PLATFORM_AMD__
    const int total_threads = blockDim.x * blockDim.y * blockDim.z;
    const int running_warps = total_threads / hw_warp_size;
#else
    const int running_warps = warp_arg.meta_group_size();
#endif

    // Always perform warp-scope reduction
    _warp<T, Ops...>(warp_arg, data);

    // If max_warps == 1 let's skip the runtime check
    if (total_warps != 1) {
        if (warp_arg.thread_rank() == 0) {
#pragma unroll
            for (int i = 0; i < elems; i++) {
                mem_access::store_shared<bytes>(reduce_buffer + elems * _warp_rank() + i, data + i);
            }
        }

        // Synchronization inside block-uniform conditional is safe
        tb.sync();

        if (_warp_rank() == 0) {
            if (warp_arg.thread_rank() < running_warps) {
#pragma unroll
                for (int i = 0; i < elems; i++) {
                    mem_access::load_shared<bytes>(
                        data + i, reduce_buffer + elems * warp_arg.thread_rank() + i);
                }
            } else {
                init<Ops...>(data);
            }

            _warp<T, Ops..., total_warps>(warp_arg, data);

#pragma unroll
            for (int i = 0; i < elems; i++) {
                mem_access::store_shared<bytes>(reduce_buffer + elems * warp_arg.thread_rank() + i,
                                                data + i);
            }
        }

        // Synchronization inside block-uniform conditional is safe
        tb.sync();

#pragma unroll
        for (int i = 0; i < elems; i++) {
            mem_access::load_shared<bytes>(data + i, reduce_buffer + _warp_rank() * elems + i);
        }
    }
}

/*
Main API implementations. For the most part, they just convert the individual
variables into arrays, which makes it easier to handle them with a single
implementation. In theory, we could use the `_block` implementation as another
option, but the nature of using a pointer is a little less safe and this allows
us to obfuscate the details of the partitioned implementation.
*/

template <ROpType Op, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
{
    _block<float, warp_bound, Op>(tb, warp, &val);
}

template <ROpType Op1, ROpType Op2, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2)
{
    float data[2] = {val1, val2};
    _block<float, warp_bound, Op1, Op2>(tb, warp, data);
    val1 = data[0];
    val2 = data[1];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3)
{
    float data[3] = {val1, val2, val3};
    _block<float, warp_bound, Op1, Op2, Op3>(tb, warp, data);
    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3,
                       float& val4)
{
    float data[4] = {val1, val2, val3, val4};
    _block<float, warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
    val4 = data[3];
}

/*
Note: the partitioned block implementations do not support non-power-of-2 partition sizes,
in order to shorten the block-scale reduction length.
*/

template <ROpType Op, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val)
{
    if (num_threads <= hw_warp_size) {
        _warp<float, Op, num_threads>(warp, &val);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        _block<float, num_warps, Op>(tb, warp, &val);
    }
}

template <ROpType Op1, ROpType Op2, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2)
{
    float data[2] = {val1, val2};

    if (num_threads <= hw_warp_size) {
        _warp<float, Op1, Op2, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        _block<float, num_warps, Op1, Op2>(tb, warp, data);
    }

    val1 = data[0];
    val2 = data[1];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3)
{
    float data[3] = {val1, val2, val3};

    if (num_threads <= hw_warp_size) {
        _warp<float, Op1, Op2, Op3, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        _block<float, num_warps, Op1, Op2, Op3>(tb, warp, data);
    }

    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3,
                                   float& val4)
{
    float data[4] = {val1, val2, val3, val4};

    if (num_threads <= hw_warp_size) {
        _warp<float, Op1, Op2, Op3, Op4, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        _block<float, num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
    }

    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
    val4 = data[3];
}

/*
Arg-reduce is a specialization of the above. We only support this with a single reduction
parameter. This only works for max/min reductions.
*/

__align__(8) struct IdxReduceResult {
    /*
    NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits
    and the val is the most significant. Changing the order of this declaration
    will break the code.
    */
    int idx;
    float val;
};

template <ROpType Op, int warpBound>
DS_D_INLINE IdxReduceResult
idx_reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float val, int idx)
{
    IdxReduceResult res = {idx, val};

    // Clear out the nan. This shouldn't be an issue for our initial applications
    if (isnan(val)) res.val = init<Op>();

    // Can do float compares as integers. By packing the index into the lower bits
    // we can just do a single int64 rather than a branch, compare, and select.
    // One side benefit of this is that it is by nature a stable algorithm and
    // will always bias ties to the higher index.
    int64_t* res_as_int = reinterpret_cast<int64_t*>(&res);

    // The way floating point compare works is normally to perform a sign comparison
    // and if they match, then do a comparison of the rest of the bits as unsigned
    // integers. Since we are bundling these, that means for negative values we need
    // to reverse the sort order, which we can do with an XOR.
    if (val < 0) { *res_as_int ^= 0x7fffffff00000000; }

    _block<int64_t, warpBound, Op>(tb, warp, res_as_int);

    // Sign bit is preserved, so we can check if we need to invert the mantissa back
    if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; }

    return res;
}

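/*
Hedged usage sketch (illustrative, not part of this header): a block-wide argmax over one
value per thread. `my_val` and `my_idx` are hypothetical caller-side values.
``` cpp
reduce::IdxReduceResult res =
    reduce::idx_reduce<reduce::ROpType::Max, reduce::max_warps>(tb, warp, my_val, my_idx);
// res.val holds the block maximum; res.idx holds its index (ties go to the higher index).
```
*/
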
}  // namespace reduce