123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625 |
- /*
- Copyright 2022 The Microsoft DeepSpeed Team
- */
- #pragma once
- #include "ds_kernel_utils.h"
- #include <cuda_fp16.h>
- #include <stdint.h>
- #ifdef BF16_AVAILABLE
- #include <cuda_bf16.h>
- #endif
namespace conversion {

// Basic primitive for constructing conversions.
//
// Call as conversion::to<TO>(value); FROM is deduced from the argument. Every supported
// (TO, FROM) pair is provided as an explicit specialization below, so the primary template is
// only instantiated for an unsupported pair. In that case it fails with a readable diagnostic
// (a self-call such as `return to(val);` cannot deduce TO and only yields an opaque
// "no matching function" error).
template <typename TO, typename FROM>
DS_D_INLINE TO to(FROM val)
{
    // `sizeof(FROM) == 0` is always false but depends on the template parameters, so the
    // assertion fires only when this fallback is actually instantiated.
    static_assert(sizeof(FROM) == 0,
                  "conversion::to<TO, FROM>: no conversion is defined for this pair of types");
}

// Specializations

/********************* Identity Conversions *********************/
/*
Identity conversions are useful in templated functions with a fixed destination type. For
example, a kernel may accept __half, __nv_bfloat16, and float inputs but always perform its core
computation in float:

    T mem_value = input[idx];
    float compute_value = conversion::to<float, T>(mem_value);

In practice the second template parameter can be left to deduction:

    float compute_value = conversion::to<float>(mem_value);

For that call to compile when T = float, an identity specialization must exist.

NOTE: The destination type (the first template parameter) cannot be deduced, even in the
trivial identity case, because it appears only in the return type.
*/

// Floating point types
template <>
DS_D_INLINE double to(double val)
{
    return val;
}

template <>
DS_D_INLINE float to(float val)
{
    return val;
}

template <>
DS_D_INLINE __half to(__half val)
{
    return val;
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val)
{
    return val;
}
#endif

// Integer types
template <>
DS_D_INLINE int8_t to(int8_t val)
{
    return val;
}

template <>
DS_D_INLINE uint8_t to(uint8_t val)
{
    return val;
}

template <>
DS_D_INLINE int16_t to(int16_t val)
{
    return val;
}

template <>
DS_D_INLINE uint16_t to(uint16_t val)
{
    return val;
}

template <>
DS_D_INLINE int32_t to(int32_t val)
{
    return val;
}

template <>
DS_D_INLINE uint32_t to(uint32_t val)
{
    return val;
}

template <>
DS_D_INLINE int64_t to(int64_t val)
{
    return val;
}

template <>
DS_D_INLINE uint64_t to(uint64_t val)
{
    return val;
}

// TODO: evaluate if we want bools

/********************* To Double Conversions *********************/

// float -> double is exact. When PTX_AVAILABLE is defined, the widening cvt instruction is
// emitted directly. PTX only allows a rounding modifier on float-to-float conversions that lose
// precision, so none is given here.
template <>
DS_D_INLINE double to(float val)
{
#ifdef PTX_AVAILABLE
    double ret_val;
    asm("cvt.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val));
    return ret_val;
#else
    return double(val);
#endif
}

// PTX has a cvt instruction for __half -> double, but there is no convenient inline interface
// for passing a single __half operand. Go through float instead; every half value is exactly
// representable as a float, so the two-step conversion is still exact.
template <>
DS_D_INLINE double to(__half val)
{
    return to<double>(__half2float(val));
}

// 64-bit integers round to nearest; the narrower integer types convert exactly.
template <>
DS_D_INLINE double to(int64_t val)
{
    return __ll2double_rn(val);
}

template <>
DS_D_INLINE double to(int32_t val)
{
    return __int2double_rn(val);
}

template <>
DS_D_INLINE double to(int16_t val)
{
    return __int2double_rn(val);
}

template <>
DS_D_INLINE double to(int8_t val)
{
    return __int2double_rn(val);
}

template <>
DS_D_INLINE double to(uint64_t val)
{
    return __ull2double_rn(val);
}

template <>
DS_D_INLINE double to(uint32_t val)
{
    return __uint2double_rn(val);
}

template <>
DS_D_INLINE double to(uint16_t val)
{
    return __uint2double_rn(val);
}

template <>
DS_D_INLINE double to(uint8_t val)
{
    return __uint2double_rn(val);
}

#ifdef BF16_AVAILABLE
// As with __half: convert through float, which represents every bf16 value exactly.
template <>
DS_D_INLINE double to(__nv_bfloat16 val)
{
    return to<double>(__bfloat162float(val));
}
#endif

/********************* To Float Conversions *********************/

template <>
DS_D_INLINE float to(double val)
{
    return __double2float_rn(val);
}

template <>
DS_D_INLINE float to(__half val)
{
    return __half2float(val);
}

// Integer -> float conversions use round-to-nearest (the *_rn intrinsics).
template <>
DS_D_INLINE float to(int64_t val)
{
    return __ll2float_rn(val);
}

template <>
DS_D_INLINE float to(int32_t val)
{
    return __int2float_rn(val);
}

template <>
DS_D_INLINE float to(int16_t val)
{
    return __int2float_rn(val);
}

template <>
DS_D_INLINE float to(int8_t val)
{
    return __int2float_rn(val);
}

template <>
DS_D_INLINE float to(uint64_t val)
{
    return __ull2float_rn(val);
}

template <>
DS_D_INLINE float to(uint32_t val)
{
    return __uint2float_rn(val);
}

template <>
DS_D_INLINE float to(uint16_t val)
{
    return __uint2float_rn(val);
}

template <>
DS_D_INLINE float to(uint8_t val)
{
    return __uint2float_rn(val);
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE float to(__nv_bfloat16 val)
{
    return __bfloat162float(val);
}
#endif

/********************* To Float2 Conversions *********************/

template <>
DS_D_INLINE float2 to(__half2 val)
{
    return __half22float2(val);
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE float2 to(__nv_bfloat162 val)
{
    return __bfloat1622float2(val);
}
#endif

/********************* To Half Conversions *********************/

template <>
DS_D_INLINE __half to(double val)
{
    return __double2half(val);
}

template <>
DS_D_INLINE __half to(float val)
{
    return __float2half(val);
}

template <>
DS_D_INLINE __half to(int64_t val)
{
    return __ll2half_rn(val);
}

template <>
DS_D_INLINE __half to(int32_t val)
{
    return __int2half_rn(val);
}

template <>
DS_D_INLINE __half to(int16_t val)
{
    return __short2half_rn(val);
}

template <>
DS_D_INLINE __half to(int8_t val)
{
    return __int2half_rn(val);
}

template <>
DS_D_INLINE __half to(uint64_t val)
{
    return __ull2half_rn(val);
}

template <>
DS_D_INLINE __half to(uint32_t val)
{
    return __uint2half_rn(val);
}

template <>
DS_D_INLINE __half to(uint16_t val)
{
    return __ushort2half_rn(val);
}

template <>
DS_D_INLINE __half to(uint8_t val)
{
    return __uint2half_rn(val);
}

#ifdef BF16_AVAILABLE
// No direct bf16 -> half intrinsic, so go through float. Precision loss and overflow are
// possible because bf16 has a wider exponent range than fp16.
template <>
DS_D_INLINE __half to(__nv_bfloat16 val)
{
    return to<__half>(to<float>(val));
}
#endif

/********************* To Half2 Conversions *********************/

template <>
DS_D_INLINE __half2 to(float2 val)
{
    return __float22half2_rn(val);
}

#ifdef BF16_AVAILABLE
// No direct bf162 -> half2 intrinsic, so go through float2.
template <>
DS_D_INLINE __half2 to(__nv_bfloat162 val)
{
    return to<__half2>(to<float2>(val));
}
#endif

/********************* To BF16 Conversions *********************/

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 to(double val)
{
    return __double2bfloat16(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(float val)
{
    return __float2bfloat16(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(int64_t val)
{
    return __ll2bfloat16_rn(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(int32_t val)
{
    return __int2bfloat16_rn(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(int16_t val)
{
    return __short2bfloat16_rn(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(int8_t val)
{
    return __int2bfloat16_rn(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(uint64_t val)
{
    return __ull2bfloat16_rn(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(uint32_t val)
{
    return __uint2bfloat16_rn(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(uint16_t val)
{
    return __ushort2bfloat16_rn(val);
}

template <>
DS_D_INLINE __nv_bfloat16 to(uint8_t val)
{
    return __uint2bfloat16_rn(val);
}
#endif

/********************* To BF162 Conversions *********************/

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat162 to(float2 val)
{
    return __float22bfloat162_rn(val);
}

// No direct half2 -> bf162 intrinsic, so go through float2.
template <>
DS_D_INLINE __nv_bfloat162 to(__half2 val)
{
    return to<__nv_bfloat162>(to<float2>(val));
}
#endif

/*
Floating point -> integer conversions below all round to nearest (the *_rn intrinsics).
The 8- and 16-bit destinations reuse the 32-bit intrinsics and narrow the result with
static_cast. Inputs outside the destination range are NOT clamped to that range, so callers
that need saturating behavior must clamp before converting.
*/

/********************* To INT64_T Conversions *********************/

template <>
DS_D_INLINE int64_t to(double val)
{
    return __double2ll_rn(val);
}

template <>
DS_D_INLINE int64_t to(float val)
{
    return __float2ll_rn(val);
}

template <>
DS_D_INLINE int64_t to(__half val)
{
    return __half2ll_rn(val);
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE int64_t to(__nv_bfloat16 val)
{
    return __bfloat162ll_rn(val);
}
#endif

/********************* To INT32_T Conversions *********************/

template <>
DS_D_INLINE int32_t to(double val)
{
    return __double2int_rn(val);
}

template <>
DS_D_INLINE int32_t to(float val)
{
    return __float2int_rn(val);
}

template <>
DS_D_INLINE int32_t to(__half val)
{
    return __half2int_rn(val);
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE int32_t to(__nv_bfloat16 val)
{
    return __bfloat162int_rn(val);
}
#endif

/********************* To INT16_T Conversions *********************/

template <>
DS_D_INLINE int16_t to(double val)
{
    return static_cast<int16_t>(__double2int_rn(val));
}

template <>
DS_D_INLINE int16_t to(float val)
{
    return static_cast<int16_t>(__float2int_rn(val));
}

template <>
DS_D_INLINE int16_t to(__half val)
{
    return static_cast<int16_t>(__half2int_rn(val));
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE int16_t to(__nv_bfloat16 val)
{
    return static_cast<int16_t>(__bfloat162int_rn(val));
}
#endif

/********************* To INT8_T Conversions *********************/

template <>
DS_D_INLINE int8_t to(double val)
{
    return static_cast<int8_t>(__double2int_rn(val));
}

template <>
DS_D_INLINE int8_t to(float val)
{
    return static_cast<int8_t>(__float2int_rn(val));
}

template <>
DS_D_INLINE int8_t to(__half val)
{
    return static_cast<int8_t>(__half2int_rn(val));
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE int8_t to(__nv_bfloat16 val)
{
    return static_cast<int8_t>(__bfloat162int_rn(val));
}
#endif

/********************* To UINT64_T Conversions *********************/

template <>
DS_D_INLINE uint64_t to(double val)
{
    return __double2ull_rn(val);
}

template <>
DS_D_INLINE uint64_t to(float val)
{
    return __float2ull_rn(val);
}

template <>
DS_D_INLINE uint64_t to(__half val)
{
    return __half2ull_rn(val);
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE uint64_t to(__nv_bfloat16 val)
{
    return __bfloat162ull_rn(val);
}
#endif

/********************* To UINT32_T Conversions *********************/

template <>
DS_D_INLINE uint32_t to(double val)
{
    return __double2uint_rn(val);
}

template <>
DS_D_INLINE uint32_t to(float val)
{
    return __float2uint_rn(val);
}

template <>
DS_D_INLINE uint32_t to(__half val)
{
    return __half2uint_rn(val);
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE uint32_t to(__nv_bfloat16 val)
{
    return __bfloat162uint_rn(val);
}
#endif

/********************* To UINT16_T Conversions *********************/

template <>
DS_D_INLINE uint16_t to(double val)
{
    return static_cast<uint16_t>(__double2uint_rn(val));
}

template <>
DS_D_INLINE uint16_t to(float val)
{
    return static_cast<uint16_t>(__float2uint_rn(val));
}

template <>
DS_D_INLINE uint16_t to(__half val)
{
    return static_cast<uint16_t>(__half2uint_rn(val));
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE uint16_t to(__nv_bfloat16 val)
{
    return static_cast<uint16_t>(__bfloat162uint_rn(val));
}
#endif

/********************* To UINT8_T Conversions *********************/

template <>
DS_D_INLINE uint8_t to(double val)
{
    return static_cast<uint8_t>(__double2uint_rn(val));
}

template <>
DS_D_INLINE uint8_t to(float val)
{
    return static_cast<uint8_t>(__float2uint_rn(val));
}

template <>
DS_D_INLINE uint8_t to(__half val)
{
    return static_cast<uint8_t>(__half2uint_rn(val));
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE uint8_t to(__nv_bfloat16 val)
{
    return static_cast<uint8_t>(__bfloat162uint_rn(val));
}
#endif

}  // namespace conversion
|