- /*
- Copyright 2022 The Microsoft DeepSpeed Team
- */
- #include "conversion_utils.h"
- #include "ds_kernel_utils.h"
- #include "quantization.h"
- #include "quantization_utils.h"
- namespace cg = cooperative_groups;
- #pragma once
namespace dequantize {
// Re-use the quantization library's type tags and parameter structs so both
// sides of the quantize/dequantize pipeline agree on representation.
using Type = quantize::Type;
template <Type qType, int numBits>
using Params = quantize::Params<qType, numBits>;

// Bytes handled per thread per vectorized memory access (defined by the
// quantization library).
constexpr int granularity = quantize::granularity;
using PackedInt4 = quantize::PackedInt4;

// Number of __half (resp. __half2) elements that fit in one
// `granularity`-byte chunk.
constexpr int h_per_chunk = granularity / sizeof(__half);
constexpr int h2_per_chunk = granularity / sizeof(__half2);
- /*
- Device function that reads quantized data from global memory, dequantizes
- it, and stores it to global memory.
- Template Arguments :
- numBits - Number of bits in quantized element. int: 4, 8
- qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
- unroll - Number of load steps to internally unroll int
- threads - Number of threads to perform dequant int
- Function arguments:
- global_output - __half pointer in global memory
- data - Quantized data in global memory
- global_params - Quantization parameters in global memory
- elems_per_group - Number of elements in each quantization group
- total_elems - Tensor size (note, does not need to be multiple of elems_per_group)
- */
- template <int numBits, Type qType, int unroll, int threads>
- DS_D_INLINE void to_global(__half* global_output,
- const int8_t* data,
- const float* global_params,
- const int elems_per_group,
- const int total_elems);
- /*
- Device function that quantizes 16 bytes of __half type input data.
- Template Arguments :
- numBits - Number of bits in quantized element. int : 8 or 4
- qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
- Function Arguments :
- local_output - Local array to store dequantized data __half* or __half2*
- data - Pointer to quantized input data. int8_t*
- Params - Parameters for quantization. Params<qType, numBits>
- */
- template <int numBits, Type qType>
- DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params);
- template <int numBits, Type qType>
- DS_D_INLINE void chunk(__half* local_output, const int8_t* data, Params<qType, numBits> q_params);
- /**************** Implementations ******************/
- template <int numBits, Type qType>
- DS_D_INLINE void chunk(__half* local_output, const int8_t* data, Params<qType, numBits> q_params)
- {
- constexpr int32_t num_elems_packed = 8 / numBits;
- constexpr int32_t iters = h_per_chunk / num_elems_packed;
- #pragma unroll
- for (int i = 0; i < iters; i++) {
- if constexpr (num_elems_packed == 1) {
- local_output[i] = q_params.dequantize(data[i]);
- } else {
- auto accessible_data = *(PackedInt4*)(&data[i]);
- local_output[2 * i] = q_params.dequantize(accessible_data.low);
- local_output[2 * i + 1] = q_params.dequantize(accessible_data.high);
- }
- }
- }
- template <int numBits, Type qType>
- DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params)
- {
- __half* local_output_cast = reinterpret_cast<__half*>(local_output);
- chunk<numBits>(local_output_cast, data, q_params);
- }
/*
Kernel body for dequantization: each thread streams `unroll` chunks of
quantized data from global memory into registers, dequantizes them with the
appropriate per-group parameters, and stores the __half results back to
global memory. The `elem_id_iter < total_elems` guards make the ragged tail
(when the tensor is not a multiple of the block stride) safe.
*/
template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void _to_global(__half* global_output,
                            const int8_t* data,
                            const float* global_params,
                            const int elems_per_group,
                            const int total_elems)
{
    cg::thread_block tb = cg::this_thread_block();
    // NOTE(review): `warp` appears unused in this function — confirm before removing.
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // Load constants
    // TODO(cmikeh2): Refactor into functions?
    // Quantized input bytes per thread per step: the int8 input is
    // numBits/16 the size of the granularity-byte __half output
    // (granularity/2 halves * numBits/8 bytes each = granularity*numBits/16).
    constexpr int load_granularity = granularity * numBits / 16;
    constexpr int load_step_stride = load_granularity * threads;
    constexpr int load_block_stride = load_step_stride * unroll;

    // Store constants (in __half elements, not bytes)
    constexpr int store_step_stride = h_per_chunk * threads;
    constexpr int store_block_stride = store_step_stride * unroll;

    // Load offsets
    const int load_block_offset = tb.group_index().x * load_block_stride;
    // Note: we can use `load_granularity` since the dtype is `int8_t`.
    const int load_thread_offset = tb.thread_index().x * load_granularity;
    const int8_t* load_base = data + load_block_offset + load_thread_offset;

    // Store offsets
    const int store_block_offset = tb.group_index().x * store_block_stride;
    const int store_thread_offset = tb.thread_index().x * h_per_chunk;
    const int elem_id_base = store_block_offset + store_thread_offset;

    // Per-thread staging buffers (registers/local memory).
    int8_t local_load_buffer[load_granularity * unroll];
    __half local_dequant_buffer[h_per_chunk * unroll];

    /*
    Note: Splitting this loop in half gave about 3-5% performance increase for reasons that aren't
    totally clear to me, so this is a deliberately weird code structure.
    */
    // First pass: issue all the global loads up front.
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        const int elem_id_iter = elem_id_base + i * store_step_stride;
        if (elem_id_iter < total_elems) {
            mem_access::load_global<load_granularity>(local_load_buffer + i * load_granularity,
                                                      load_base + i * load_step_stride);
        }
    }

    // Second pass: dequantize each staged chunk and store it back out.
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        const int elem_id_iter = elem_id_base + i * store_step_stride;
        if (elem_id_iter < total_elems) {
            // TODO(cmikeh2): Can we amortize this division? Perform once on the first iteration and
            // use indexing math to do division free interpolation of the successive groups?
            const int group_index = elem_id_iter / elems_per_group;
            Params<qType, numBits> q_params(global_params, group_index);
            chunk<numBits, qType>(local_dequant_buffer + i * h_per_chunk,
                                  local_load_buffer + i * load_granularity,
                                  q_params);
            mem_access::store_global<granularity>(global_output + elem_id_iter,
                                                  local_dequant_buffer + i * h_per_chunk);
        }
    }
}
- template <int numBits, Type qType, int unroll, int threads>
- DS_D_INLINE void to_global(__half* global_output,
- const int8_t* data,
- const float* global_params,
- const int elems_per_group,
- const int total_elems)
- {
- if constexpr (numBits == 4 || numBits == 8) {
- _to_global<numBits, qType, unroll, threads>(
- global_output, data, global_params, elems_per_group, total_elems);
- } else if constexpr (numBits == 3) {
- // TODO(cmikeh2): Need this implementation
- assert(false);
- } else {
- assert(false);
- }
- }
- } // namespace dequantize