conversion_utils.h 12 KB


  1. /*
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. */
  4. #pragma once
  5. #include "ds_kernel_utils.h"
  6. #include <cuda_fp16.h>
  7. #include <stdint.h>
  8. #ifdef BF16_AVAILABLE
  9. #include <cuda_bf16.h>
  10. #endif
  11. namespace conversion {
  12. // Basic primitive for constructing conversions
  13. template <typename TO, typename FROM>
  14. DS_D_INLINE TO to(FROM val)
  15. {
  16. return to(val);
  17. }
  18. // Specializations
  19. /********************* Identity Conversions *********************/
  20. /*
  21. Identity conversions are useful in templated functions where we might have
  22. a fixed destination type. For example, I might have a kernel that accepts
  23. __half, __nv_bfloat16, and float but always want to do the core computation
  24. at floating point:
  25. T mem_value = input[idx];
  26. float compute_value = conversion::to<float, T>(mem_value);
  27. In practice, we should be able to elide the second template parameter:
  28. float compute_val = conversion::to<float>(mem_value);
  29. In this case, we need an implementation to handle the T = float case
NOTE: The destination type (the first template parameter) can never be deduced, because
it only appears in the return type; it must always be specified explicitly.
  32. */
  33. // Floating point types
  34. template <>
  35. DS_D_INLINE double to(double val)
  36. {
  37. return val;
  38. }
  39. template <>
  40. DS_D_INLINE float to(float val)
  41. {
  42. return val;
  43. }
  44. template <>
  45. DS_D_INLINE __half to(__half val)
  46. {
  47. return val;
  48. }
  49. #ifdef BF16_AVAILABLE
  50. template <>
  51. DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val)
  52. {
  53. return val;
  54. }
  55. #endif
  56. // Integer types
  57. template <>
  58. DS_D_INLINE int8_t to(int8_t val)
  59. {
  60. return val;
  61. }
  62. template <>
  63. DS_D_INLINE uint8_t to(uint8_t val)
  64. {
  65. return val;
  66. }
  67. template <>
  68. DS_D_INLINE int16_t to(int16_t val)
  69. {
  70. return val;
  71. }
  72. template <>
  73. DS_D_INLINE uint16_t to(uint16_t val)
  74. {
  75. return val;
  76. }
  77. template <>
  78. DS_D_INLINE int32_t to(int32_t val)
  79. {
  80. return val;
  81. }
  82. template <>
  83. DS_D_INLINE uint32_t to(uint32_t val)
  84. {
  85. return val;
  86. }
  87. template <>
  88. DS_D_INLINE int64_t to(int64_t val)
  89. {
  90. return val;
  91. }
  92. template <>
  93. DS_D_INLINE uint64_t to(uint64_t val)
  94. {
  95. return val;
  96. }
  97. // TODO: evaluate if we want bools
  98. /********************* To Double Conversions *********************/
  99. // * to double variants
  100. // Would normally like to not use C cast, but this is an important enough conversion
  101. // to keep
  102. template <>
  103. DS_D_INLINE double to(float val)
  104. {
  105. #ifdef PTX_AVAILABLE
  106. double ret_val;
  107. asm("ctv.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val));
  108. return ret_val;
  109. #else
  110. return double(val);
  111. #endif
  112. }
  113. // Note: there is a CVT instruction for __half -> double, but there's no inline interface
  114. // for passing a single half value
  115. template <>
  116. DS_D_INLINE double to(__half val)
  117. {
  118. return to<double>(__half2float(val));
  119. }
  120. template <>
  121. DS_D_INLINE double to(int64_t val)
  122. {
  123. return __ll2double_rn(val);
  124. }
  125. template <>
  126. DS_D_INLINE double to(int32_t val)
  127. {
  128. return __int2double_rn(val);
  129. }
  130. template <>
  131. DS_D_INLINE double to(int16_t val)
  132. {
  133. return __int2double_rn(val);
  134. }
  135. template <>
  136. DS_D_INLINE double to(int8_t val)
  137. {
  138. return __int2double_rn(val);
  139. }
  140. template <>
  141. DS_D_INLINE double to(uint64_t val)
  142. {
  143. return __ull2double_rn(val);
  144. }
  145. template <>
  146. DS_D_INLINE double to(uint32_t val)
  147. {
  148. return __uint2double_rn(val);
  149. }
  150. template <>
  151. DS_D_INLINE double to(uint16_t val)
  152. {
  153. return __uint2double_rn(val);
  154. }
  155. template <>
  156. DS_D_INLINE double to(uint8_t val)
  157. {
  158. return __uint2double_rn(val);
  159. }
  160. // Same applies here
  161. #ifdef BF16_AVAILABLE
  162. template <>
  163. DS_D_INLINE double to(__nv_bfloat16 val)
  164. {
  165. return to<double>(__bfloat162float(val));
  166. }
  167. #endif
  168. /********************* To Float Conversions *********************/
  169. template <>
  170. DS_D_INLINE float to(double val)
  171. {
  172. return __double2float_rn(val);
  173. }
  174. template <>
  175. DS_D_INLINE float to(__half val)
  176. {
  177. return __half2float(val);
  178. }
  179. template <>
  180. DS_D_INLINE float to(int64_t val)
  181. {
  182. return __ll2float_rn(val);
  183. }
  184. template <>
  185. DS_D_INLINE float to(int32_t val)
  186. {
  187. return __int2float_rn(val);
  188. }
  189. template <>
  190. DS_D_INLINE float to(int16_t val)
  191. {
  192. return __int2float_rn(val);
  193. }
  194. template <>
  195. DS_D_INLINE float to(int8_t val)
  196. {
  197. return __int2float_rn(val);
  198. }
  199. template <>
  200. DS_D_INLINE float to(uint64_t val)
  201. {
  202. return __ull2float_rn(val);
  203. }
  204. template <>
  205. DS_D_INLINE float to(uint32_t val)
  206. {
  207. return __uint2float_rn(val);
  208. }
  209. template <>
  210. DS_D_INLINE float to(uint16_t val)
  211. {
  212. return __uint2float_rn(val);
  213. }
  214. template <>
  215. DS_D_INLINE float to(uint8_t val)
  216. {
  217. return __uint2float_rn(val);
  218. }
  219. #ifdef BF16_AVAILABLE
  220. template <>
  221. DS_D_INLINE float to(__nv_bfloat16 val)
  222. {
  223. return __bfloat162float(val);
  224. }
  225. #endif
  226. /********************* To Float2 Conversions *********************/
  227. template <>
  228. DS_D_INLINE float2 to(__half2 val)
  229. {
  230. return __half22float2(val);
  231. }
  232. #ifdef BF16_AVAILABLE
  233. template <>
  234. DS_D_INLINE float2 to(__nv_bfloat162 val)
  235. {
  236. return __bfloat1622float2(val);
  237. }
  238. #endif
  239. /********************* To Half Conversions *********************/
  240. template <>
  241. DS_D_INLINE __half to(double val)
  242. {
  243. return __double2half(val);
  244. }
  245. template <>
  246. DS_D_INLINE __half to(float val)
  247. {
  248. return __float2half(val);
  249. }
  250. template <>
  251. DS_D_INLINE __half to(int64_t val)
  252. {
  253. return __ll2half_rn(val);
  254. }
  255. template <>
  256. DS_D_INLINE __half to(int32_t val)
  257. {
  258. return __int2half_rn(val);
  259. }
  260. template <>
  261. DS_D_INLINE __half to(int16_t val)
  262. {
  263. return __short2half_rn(val);
  264. }
  265. template <>
  266. DS_D_INLINE __half to(int8_t val)
  267. {
  268. return __int2half_rn(val);
  269. }
  270. template <>
  271. DS_D_INLINE __half to(uint64_t val)
  272. {
  273. return __ull2half_rn(val);
  274. }
  275. template <>
  276. DS_D_INLINE __half to(uint32_t val)
  277. {
  278. return __uint2half_rn(val);
  279. }
  280. template <>
  281. DS_D_INLINE __half to(uint16_t val)
  282. {
  283. return __ushort2half_rn(val);
  284. }
  285. template <>
  286. DS_D_INLINE __half to(uint8_t val)
  287. {
  288. return __uint2half_rn(val);
  289. }
  290. #ifdef BF16_AVAILABLE
  291. // No direct conversion
  292. template <>
  293. DS_D_INLINE __half to(__nv_bfloat16 val)
  294. {
  295. return to<__half>(to<float>(val));
  296. }
  297. #endif
  298. /********************* To Half2 Conversions *********************/
  299. template <>
  300. DS_D_INLINE __half2 to(float2 val)
  301. {
  302. return __float22half2_rn(val);
  303. }
  304. #ifdef BF16_AVAILABLE
  305. // No direct conversion
  306. template <>
  307. DS_D_INLINE __half2 to(__nv_bfloat162 val)
  308. {
  309. return to<__half2>(to<float2>(val));
  310. }
  311. #endif
  312. /********************* To BF16 Conversions *********************/
  313. #ifdef BF16_AVAILABLE
  314. template <>
  315. DS_D_INLINE __nv_bfloat16 to(double val)
  316. {
  317. return __double2bfloat16(val);
  318. }
  319. template <>
  320. DS_D_INLINE __nv_bfloat16 to(float val)
  321. {
  322. return __float2bfloat16(val);
  323. }
  324. template <>
  325. DS_D_INLINE __nv_bfloat16 to(int64_t val)
  326. {
  327. return __ll2bfloat16_rn(val);
  328. }
  329. template <>
  330. DS_D_INLINE __nv_bfloat16 to(int32_t val)
  331. {
  332. return __int2bfloat16_rn(val);
  333. }
  334. template <>
  335. DS_D_INLINE __nv_bfloat16 to(int16_t val)
  336. {
  337. return __short2bfloat16_rn(val);
  338. }
  339. template <>
  340. DS_D_INLINE __nv_bfloat16 to(int8_t val)
  341. {
  342. return __int2bfloat16_rn(val);
  343. }
  344. template <>
  345. DS_D_INLINE __nv_bfloat16 to(uint64_t val)
  346. {
  347. return __ull2bfloat16_rn(val);
  348. }
  349. template <>
  350. DS_D_INLINE __nv_bfloat16 to(uint32_t val)
  351. {
  352. return __uint2bfloat16_rn(val);
  353. }
  354. template <>
  355. DS_D_INLINE __nv_bfloat16 to(uint16_t val)
  356. {
  357. return __ushort2bfloat16_rn(val);
  358. }
  359. template <>
  360. DS_D_INLINE __nv_bfloat16 to(uint8_t val)
  361. {
  362. return __uint2bfloat16_rn(val);
  363. }
  364. #endif
  365. /********************* To BF162 Conversions *********************/
  366. #ifdef BF16_AVAILABLE
  367. template <>
  368. DS_D_INLINE __nv_bfloat162 to(float2 val)
  369. {
  370. return __float22bfloat162_rn(val);
  371. }
  372. template <>
  373. DS_D_INLINE __nv_bfloat162 to(__half2 val)
  374. {
  375. return to<__nv_bfloat162>(to<float2>(val));
  376. }
  377. #endif
  378. /********************* To INT64_T Conversions *********************/
  379. template <>
  380. DS_D_INLINE int64_t to(double val)
  381. {
  382. return __double2ll_rn(val);
  383. }
  384. template <>
  385. DS_D_INLINE int64_t to(float val)
  386. {
  387. return __float2ll_rn(val);
  388. }
  389. template <>
  390. DS_D_INLINE int64_t to(__half val)
  391. {
  392. return __half2ll_rn(val);
  393. }
  394. // No direct support for integer casts at the C++ level and I don't feel they're so important
  395. // to demand an PTX at this time
  396. #ifdef BF16_AVAILABLE
  397. template <>
  398. DS_D_INLINE int64_t to(__nv_bfloat16 val)
  399. {
  400. return __bfloat162ll_rn(val);
  401. }
  402. #endif
  403. /********************* To INT32_T Conversions *********************/
  404. template <>
  405. DS_D_INLINE int32_t to(double val)
  406. {
  407. return __double2int_rn(val);
  408. }
  409. template <>
  410. DS_D_INLINE int32_t to(float val)
  411. {
  412. return __float2int_rn(val);
  413. }
  414. template <>
  415. DS_D_INLINE int32_t to(__half val)
  416. {
  417. return __half2int_rn(val);
  418. }
  419. // No direct support for integer casts at the C++ level and I don't feel they're so important
  420. // to demand an PTX at this time
  421. #ifdef BF16_AVAILABLE
  422. template <>
  423. DS_D_INLINE int32_t to(__nv_bfloat16 val)
  424. {
  425. return __bfloat162int_rn(val);
  426. }
  427. #endif
  428. /********************* To INT16_T Conversions *********************/
  429. template <>
  430. DS_D_INLINE int16_t to(double val)
  431. {
  432. return __double2int_rn(val);
  433. }
  434. template <>
  435. DS_D_INLINE int16_t to(float val)
  436. {
  437. return __float2int_rn(val);
  438. }
  439. template <>
  440. DS_D_INLINE int16_t to(__half val)
  441. {
  442. return __half2int_rn(val);
  443. }
  444. // No direct support for integer casts at the C++ level and I don't feel they're so important
  445. // to demand an PTX at this time
  446. #ifdef BF16_AVAILABLE
  447. template <>
  448. DS_D_INLINE int16_t to(__nv_bfloat16 val)
  449. {
  450. return __bfloat162int_rn(val);
  451. }
  452. #endif
  453. /********************* To INT8_T Conversions *********************/
  454. template <>
  455. DS_D_INLINE int8_t to(double val)
  456. {
  457. return __double2int_rn(val);
  458. }
  459. template <>
  460. DS_D_INLINE int8_t to(float val)
  461. {
  462. return __float2int_rn(val);
  463. }
  464. template <>
  465. DS_D_INLINE int8_t to(__half val)
  466. {
  467. return __half2int_rn(val);
  468. }
  469. // No direct support for integer casts at the C++ level and I don't feel they're so important
  470. // to demand an PTX at this time
  471. #ifdef BF16_AVAILABLE
  472. template <>
  473. DS_D_INLINE int8_t to(__nv_bfloat16 val)
  474. {
  475. return __bfloat162int_rn(val);
  476. }
  477. #endif
  478. /********************* To UINT64_T Conversions *********************/
  479. template <>
  480. DS_D_INLINE uint64_t to(double val)
  481. {
  482. return __double2ull_rn(val);
  483. }
  484. template <>
  485. DS_D_INLINE uint64_t to(float val)
  486. {
  487. return __float2ull_rn(val);
  488. }
  489. template <>
  490. DS_D_INLINE uint64_t to(__half val)
  491. {
  492. return __half2ull_rn(val);
  493. }
  494. // No direct support for integer casts at the C++ level and I don't feel they're so important
  495. // to demand an PTX at this time
  496. #ifdef BF16_AVAILABLE
  497. template <>
  498. DS_D_INLINE uint64_t to(__nv_bfloat16 val)
  499. {
  500. return __bfloat162ull_rn(val);
  501. }
  502. #endif
  503. /********************* To UINT32_T Conversions *********************/
  504. template <>
  505. DS_D_INLINE uint32_t to(double val)
  506. {
  507. return __double2uint_rn(val);
  508. }
  509. template <>
  510. DS_D_INLINE uint32_t to(float val)
  511. {
  512. return __float2uint_rn(val);
  513. }
  514. template <>
  515. DS_D_INLINE uint32_t to(__half val)
  516. {
  517. return __half2uint_rn(val);
  518. }
  519. // No direct support for integer casts at the C++ level and I don't feel they're so important
  520. // to demand an PTX at this time
  521. #ifdef BF16_AVAILABLE
  522. template <>
  523. DS_D_INLINE uint32_t to(__nv_bfloat16 val)
  524. {
  525. return __bfloat162uint_rn(val);
  526. }
  527. #endif
  528. /********************* To UINT16_T Conversions *********************/
  529. template <>
  530. DS_D_INLINE uint16_t to(double val)
  531. {
  532. return __double2uint_rn(val);
  533. }
  534. template <>
  535. DS_D_INLINE uint16_t to(float val)
  536. {
  537. return __float2uint_rn(val);
  538. }
  539. template <>
  540. DS_D_INLINE uint16_t to(__half val)
  541. {
  542. return __half2uint_rn(val);
  543. }
  544. // No direct support for integer casts at the C++ level and I don't feel they're so important
  545. // to demand an PTX at this time
  546. #ifdef BF16_AVAILABLE
  547. template <>
  548. DS_D_INLINE uint16_t to(__nv_bfloat16 val)
  549. {
  550. return __bfloat162uint_rn(val);
  551. }
  552. #endif
  553. /********************* To UINT8_T Conversions *********************/
  554. template <>
  555. DS_D_INLINE uint8_t to(double val)
  556. {
  557. return __double2uint_rn(val);
  558. }
  559. template <>
  560. DS_D_INLINE uint8_t to(float val)
  561. {
  562. return __float2uint_rn(val);
  563. }
  564. template <>
  565. DS_D_INLINE uint8_t to(__half val)
  566. {
  567. return __half2uint_rn(val);
  568. }
  569. // No direct support for integer casts at the C++ level and I don't feel they're so important
  570. // to demand an PTX at this time
  571. #ifdef BF16_AVAILABLE
  572. template <>
  573. DS_D_INLINE uint8_t to(__nv_bfloat16 val)
  574. {
  575. return __bfloat162uint_rn(val);
  576. }
  577. #endif
  578. } // namespace conversion