conversion_utils.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #pragma once
  5. #include "ds_kernel_utils.h"
  6. #include <stdint.h>
  7. #ifdef BF16_AVAILABLE
  8. #include <cuda_bf16.h>
  9. #endif
  10. namespace conversion {
  11. // Basic primitive for constructing conversions
  12. template <typename TO, typename FROM>
  13. DS_D_INLINE TO to(FROM val)
  14. {
  15. return to(val);
  16. }
  17. // Specializations
  18. /********************* Identity Conversions *********************/
  19. /*
  20. Identity conversions are useful in templated functions where we might have
  21. a fixed destination type. For example, I might have a kernel that accepts
  22. __half, __nv_bfloat16, and float but always want to do the core computation
  23. at floating point:
  24. T mem_value = input[idx];
  25. float compute_value = conversion::to<float, T>(mem_value);
  26. In practice, we should be able to elide the second template parameter:
  27. float compute_val = conversion::to<float>(mem_value);
  28. In this case, we need an implementation to handle the T = float case
  29. NOTE: The type inferencing system appears to be unable to handle inferring the first
  30. template parameter, even in the trivial case.
  31. */
  32. // Floating point types
  33. template <>
  34. DS_D_INLINE double to(double val)
  35. {
  36. return val;
  37. }
  38. template <>
  39. DS_D_INLINE float to(float val)
  40. {
  41. return val;
  42. }
  43. template <>
  44. DS_D_INLINE __half to(__half val)
  45. {
  46. return val;
  47. }
  48. #ifdef BF16_AVAILABLE
  49. template <>
  50. DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val)
  51. {
  52. return val;
  53. }
  54. #endif
  55. // Integer types
  56. template <>
  57. DS_D_INLINE int8_t to(int8_t val)
  58. {
  59. return val;
  60. }
  61. template <>
  62. DS_D_INLINE uint8_t to(uint8_t val)
  63. {
  64. return val;
  65. }
  66. template <>
  67. DS_D_INLINE int16_t to(int16_t val)
  68. {
  69. return val;
  70. }
  71. template <>
  72. DS_D_INLINE uint16_t to(uint16_t val)
  73. {
  74. return val;
  75. }
  76. template <>
  77. DS_D_INLINE int32_t to(int32_t val)
  78. {
  79. return val;
  80. }
  81. template <>
  82. DS_D_INLINE uint32_t to(uint32_t val)
  83. {
  84. return val;
  85. }
  86. template <>
  87. DS_D_INLINE int64_t to(int64_t val)
  88. {
  89. return val;
  90. }
  91. template <>
  92. DS_D_INLINE uint64_t to(uint64_t val)
  93. {
  94. return val;
  95. }
  96. // TODO: evaluate if we want bools
  97. /********************* To Double Conversions *********************/
  98. // * to double variants
  99. // Would normally like to not use C cast, but this is an important enough conversion
  100. // to keep
  101. template <>
  102. DS_D_INLINE double to(float val)
  103. {
  104. #ifdef PTX_AVAILABLE
  105. double ret_val;
  106. asm("ctv.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val));
  107. return ret_val;
  108. #else
  109. return double(val);
  110. #endif
  111. }
  112. // Note: there is a CVT instruction for __half -> double, but there's no inline interface
  113. // for passing a single half value
  114. template <>
  115. DS_D_INLINE double to(__half val)
  116. {
  117. return to<double>(__half2float(val));
  118. }
  119. template <>
  120. DS_D_INLINE double to(int64_t val)
  121. {
  122. return __ll2double_rn(val);
  123. }
  124. template <>
  125. DS_D_INLINE double to(int32_t val)
  126. {
  127. return __int2double_rn(val);
  128. }
  129. template <>
  130. DS_D_INLINE double to(int16_t val)
  131. {
  132. return __int2double_rn(val);
  133. }
  134. template <>
  135. DS_D_INLINE double to(int8_t val)
  136. {
  137. return __int2double_rn(val);
  138. }
  139. template <>
  140. DS_D_INLINE double to(uint64_t val)
  141. {
  142. return __ull2double_rn(val);
  143. }
  144. template <>
  145. DS_D_INLINE double to(uint32_t val)
  146. {
  147. return __uint2double_rn(val);
  148. }
  149. template <>
  150. DS_D_INLINE double to(uint16_t val)
  151. {
  152. return __uint2double_rn(val);
  153. }
  154. template <>
  155. DS_D_INLINE double to(uint8_t val)
  156. {
  157. return __uint2double_rn(val);
  158. }
  159. // Same applies here
  160. #ifdef BF16_AVAILABLE
  161. template <>
  162. DS_D_INLINE double to(__nv_bfloat16 val)
  163. {
  164. return to<double>(__bfloat162float(val));
  165. }
  166. #endif
  167. /********************* To Float Conversions *********************/
  168. template <>
  169. DS_D_INLINE float to(double val)
  170. {
  171. return __double2float_rn(val);
  172. }
  173. template <>
  174. DS_D_INLINE float to(__half val)
  175. {
  176. return __half2float(val);
  177. }
  178. template <>
  179. DS_D_INLINE float to(int64_t val)
  180. {
  181. return __ll2float_rn(val);
  182. }
  183. template <>
  184. DS_D_INLINE float to(int32_t val)
  185. {
  186. return __int2float_rn(val);
  187. }
  188. template <>
  189. DS_D_INLINE float to(int16_t val)
  190. {
  191. return __int2float_rn(val);
  192. }
  193. template <>
  194. DS_D_INLINE float to(int8_t val)
  195. {
  196. return __int2float_rn(val);
  197. }
  198. template <>
  199. DS_D_INLINE float to(uint64_t val)
  200. {
  201. return __ull2float_rn(val);
  202. }
  203. template <>
  204. DS_D_INLINE float to(uint32_t val)
  205. {
  206. return __uint2float_rn(val);
  207. }
  208. template <>
  209. DS_D_INLINE float to(uint16_t val)
  210. {
  211. return __uint2float_rn(val);
  212. }
  213. template <>
  214. DS_D_INLINE float to(uint8_t val)
  215. {
  216. return __uint2float_rn(val);
  217. }
  218. #ifdef BF16_AVAILABLE
  219. template <>
  220. DS_D_INLINE float to(__nv_bfloat16 val)
  221. {
  222. return __bfloat162float(val);
  223. }
  224. #endif
  225. /********************* To Float2 Conversions *********************/
  226. template <>
  227. DS_D_INLINE float2 to(__half2 val)
  228. {
  229. return __half22float2(val);
  230. }
  231. #ifdef BF16_AVAILABLE
  232. template <>
  233. DS_D_INLINE float2 to(__nv_bfloat162 val)
  234. {
  235. return __bfloat1622float2(val);
  236. }
  237. #endif
  238. /********************* To Half Conversions *********************/
  239. template <>
  240. DS_D_INLINE __half to(double val)
  241. {
  242. #ifdef __HIP_PLATFORM_HCC__
  243. float val_f = __double2float_rn(val);
  244. return __float2half(val_f);
  245. #else
  246. return __double2half(val);
  247. #endif
  248. }
  249. template <>
  250. DS_D_INLINE __half to(float val)
  251. {
  252. return __float2half(val);
  253. }
  254. template <>
  255. DS_D_INLINE __half to(int64_t val)
  256. {
  257. return __ll2half_rn(val);
  258. }
  259. template <>
  260. DS_D_INLINE __half to(int32_t val)
  261. {
  262. return __int2half_rn(val);
  263. }
  264. template <>
  265. DS_D_INLINE __half to(int16_t val)
  266. {
  267. return __short2half_rn(val);
  268. }
  269. template <>
  270. DS_D_INLINE __half to(int8_t val)
  271. {
  272. return __int2half_rn(val);
  273. }
  274. template <>
  275. DS_D_INLINE __half to(uint64_t val)
  276. {
  277. return __ull2half_rn(val);
  278. }
  279. template <>
  280. DS_D_INLINE __half to(uint32_t val)
  281. {
  282. return __uint2half_rn(val);
  283. }
  284. template <>
  285. DS_D_INLINE __half to(uint16_t val)
  286. {
  287. return __ushort2half_rn(val);
  288. }
  289. template <>
  290. DS_D_INLINE __half to(uint8_t val)
  291. {
  292. return __uint2half_rn(val);
  293. }
  294. #ifdef BF16_AVAILABLE
  295. // No direct conversion
  296. template <>
  297. DS_D_INLINE __half to(__nv_bfloat16 val)
  298. {
  299. return to<__half>(to<float>(val));
  300. }
  301. #endif
  302. /********************* To Half2 Conversions *********************/
  303. template <>
  304. DS_D_INLINE __half2 to(float2 val)
  305. {
  306. return __float22half2_rn(val);
  307. }
  308. template <>
  309. DS_D_INLINE __half2 to(float val)
  310. {
  311. return __float2half2_rn(val);
  312. }
  313. #ifdef BF16_AVAILABLE
  314. // No direct conversion
  315. template <>
  316. DS_D_INLINE __half2 to(__nv_bfloat162 val)
  317. {
  318. return to<__half2>(to<float2>(val));
  319. }
  320. #endif
  321. /********************* To BF16 Conversions *********************/
  // Every conversion producing __nv_bfloat16 is compiled only when the build
  // enables BF16 support.
  #ifdef BF16_AVAILABLE
  // Floating point -> __nv_bfloat16.
  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, double>(double val)
  {
      return __double2bfloat16(val);
  }

  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, float>(float val)
  {
      return __float2bfloat16(val);
  }

  // Signed integers -> __nv_bfloat16. int16_t has a dedicated short intrinsic;
  // int8_t is routed through the int intrinsic.
  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, int64_t>(int64_t val)
  {
      return __ll2bfloat16_rn(val);
  }

  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, int32_t>(int32_t val)
  {
      return __int2bfloat16_rn(val);
  }

  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, int16_t>(int16_t val)
  {
      return __short2bfloat16_rn(val);
  }

  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, int8_t>(int8_t val)
  {
      return __int2bfloat16_rn(static_cast<int>(val));
  }

  // Unsigned integers -> __nv_bfloat16. uint16_t has a dedicated unsigned short
  // intrinsic; uint8_t is routed through the unsigned int intrinsic.
  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, uint64_t>(uint64_t val)
  {
      return __ull2bfloat16_rn(val);
  }

  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, uint32_t>(uint32_t val)
  {
      return __uint2bfloat16_rn(val);
  }

  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, uint16_t>(uint16_t val)
  {
      return __ushort2bfloat16_rn(val);
  }

  template <>
  DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, uint8_t>(uint8_t val)
  {
      return __uint2bfloat16_rn(static_cast<unsigned int>(val));
  }
  #endif
  374. /********************* To BF162 Conversions *********************/
  // Every conversion producing packed __nv_bfloat162 is compiled only when the
  // build enables BF16 support.
  #ifdef BF16_AVAILABLE
  // float2 -> packed __nv_bfloat162.
  template <>
  DS_D_INLINE __nv_bfloat162 to<__nv_bfloat162, float2>(float2 val)
  {
      return __float22bfloat162_rn(val);
  }

  // Single float -> packed __nv_bfloat162 via __float2bfloat162_rn.
  template <>
  DS_D_INLINE __nv_bfloat162 to<__nv_bfloat162, float>(float val)
  {
      return __float2bfloat162_rn(val);
  }

  // __half2 -> __nv_bfloat162, routed through float2.
  template <>
  DS_D_INLINE __nv_bfloat162 to<__nv_bfloat162, __half2>(__half2 val)
  {
      const float2 as_float2 = to<float2>(val);
      return to<__nv_bfloat162>(as_float2);
  }
  #endif
  392. /********************* To INT64_T Conversions *********************/
  393. template <>
  394. DS_D_INLINE int64_t to(double val)
  395. {
  396. return __double2ll_rn(val);
  397. }
  398. template <>
  399. DS_D_INLINE int64_t to(float val)
  400. {
  401. return __float2ll_rn(val);
  402. }
  403. template <>
  404. DS_D_INLINE int64_t to(__half val)
  405. {
  406. return __half2ll_rn(val);
  407. }
  408. // No direct support for integer casts at the C++ level and I don't feel they're so important
  409. // to demand an PTX at this time
  410. #ifdef BF16_AVAILABLE
  411. template <>
  412. DS_D_INLINE int64_t to(__nv_bfloat16 val)
  413. {
  414. return __bfloat162ll_rn(val);
  415. }
  416. #endif
  417. /********************* To INT32_T Conversions *********************/
  418. template <>
  419. DS_D_INLINE int32_t to(double val)
  420. {
  421. return __double2int_rn(val);
  422. }
  423. template <>
  424. DS_D_INLINE int32_t to(float val)
  425. {
  426. return __float2int_rn(val);
  427. }
  428. template <>
  429. DS_D_INLINE int32_t to(__half val)
  430. {
  431. return __half2int_rn(val);
  432. }
  433. // No direct support for integer casts at the C++ level and I don't feel they're so important
  434. // to demand an PTX at this time
  435. #ifdef BF16_AVAILABLE
  436. template <>
  437. DS_D_INLINE int32_t to(__nv_bfloat16 val)
  438. {
  439. return __bfloat162int_rn(val);
  440. }
  441. #endif
  442. /********************* To INT16_T Conversions *********************/
  443. template <>
  444. DS_D_INLINE int16_t to(double val)
  445. {
  446. return __double2int_rn(val);
  447. }
  448. template <>
  449. DS_D_INLINE int16_t to(float val)
  450. {
  451. return __float2int_rn(val);
  452. }
  453. template <>
  454. DS_D_INLINE int16_t to(__half val)
  455. {
  456. return __half2int_rn(val);
  457. }
  458. // No direct support for integer casts at the C++ level and I don't feel they're so important
  459. // to demand an PTX at this time
  460. #ifdef BF16_AVAILABLE
  461. template <>
  462. DS_D_INLINE int16_t to(__nv_bfloat16 val)
  463. {
  464. return __bfloat162int_rn(val);
  465. }
  466. #endif
  467. /********************* To INT8_T Conversions *********************/
  468. template <>
  469. DS_D_INLINE int8_t to(double val)
  470. {
  471. return __double2int_rn(val);
  472. }
  473. template <>
  474. DS_D_INLINE int8_t to(float val)
  475. {
  476. return __float2int_rn(val);
  477. }
  478. template <>
  479. DS_D_INLINE int8_t to(__half val)
  480. {
  481. return __half2int_rn(val);
  482. }
  483. // No direct support for integer casts at the C++ level and I don't feel they're so important
  484. // to demand an PTX at this time
  485. #ifdef BF16_AVAILABLE
  486. template <>
  487. DS_D_INLINE int8_t to(__nv_bfloat16 val)
  488. {
  489. return __bfloat162int_rn(val);
  490. }
  491. #endif
  492. /********************* To UINT64_T Conversions *********************/
  493. template <>
  494. DS_D_INLINE uint64_t to(double val)
  495. {
  496. return __double2ull_rn(val);
  497. }
  498. template <>
  499. DS_D_INLINE uint64_t to(float val)
  500. {
  501. return __float2ull_rn(val);
  502. }
  503. template <>
  504. DS_D_INLINE uint64_t to(__half val)
  505. {
  506. return __half2ull_rn(val);
  507. }
  508. // No direct support for integer casts at the C++ level and I don't feel they're so important
  509. // to demand an PTX at this time
  510. #ifdef BF16_AVAILABLE
  511. template <>
  512. DS_D_INLINE uint64_t to(__nv_bfloat16 val)
  513. {
  514. return __bfloat162ull_rn(val);
  515. }
  516. #endif
  517. /********************* To UINT32_T Conversions *********************/
  518. template <>
  519. DS_D_INLINE uint32_t to(double val)
  520. {
  521. return __double2uint_rn(val);
  522. }
  523. template <>
  524. DS_D_INLINE uint32_t to(float val)
  525. {
  526. return __float2uint_rn(val);
  527. }
  528. template <>
  529. DS_D_INLINE uint32_t to(__half val)
  530. {
  531. return __half2uint_rn(val);
  532. }
  533. // No direct support for integer casts at the C++ level and I don't feel they're so important
  534. // to demand an PTX at this time
  535. #ifdef BF16_AVAILABLE
  536. template <>
  537. DS_D_INLINE uint32_t to(__nv_bfloat16 val)
  538. {
  539. return __bfloat162uint_rn(val);
  540. }
  541. #endif
  542. /********************* To UINT16_T Conversions *********************/
  543. template <>
  544. DS_D_INLINE uint16_t to(double val)
  545. {
  546. return __double2uint_rn(val);
  547. }
  548. template <>
  549. DS_D_INLINE uint16_t to(float val)
  550. {
  551. return __float2uint_rn(val);
  552. }
  553. template <>
  554. DS_D_INLINE uint16_t to(__half val)
  555. {
  556. return __half2uint_rn(val);
  557. }
  558. // No direct support for integer casts at the C++ level and I don't feel they're so important
  559. // to demand an PTX at this time
  560. #ifdef BF16_AVAILABLE
  561. template <>
  562. DS_D_INLINE uint16_t to(__nv_bfloat16 val)
  563. {
  564. return __bfloat162uint_rn(val);
  565. }
  566. #endif
  567. /********************* To UINT8_T Conversions *********************/
  568. template <>
  569. DS_D_INLINE uint8_t to(double val)
  570. {
  571. return __double2uint_rn(val);
  572. }
  573. template <>
  574. DS_D_INLINE uint8_t to(float val)
  575. {
  576. return __float2uint_rn(val);
  577. }
  578. template <>
  579. DS_D_INLINE uint8_t to(__half val)
  580. {
  581. return __half2uint_rn(val);
  582. }
  583. // No direct support for integer casts at the C++ level and I don't feel they're so important
  584. // to demand an PTX at this time
  585. #ifdef BF16_AVAILABLE
  586. template <>
  587. DS_D_INLINE uint8_t to(__nv_bfloat16 val)
  588. {
  589. return __bfloat162uint_rn(val);
  590. }
  591. #endif
  592. } // namespace conversion