// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

namespace reduce {

enum class ROpType {
    // Addition
    Add,

    // Maximum reduction
    Max,

    // Minimum reduction
    Min,
};

constexpr int max_threads = 1024;
constexpr int max_warps = max_threads / hw_warp_size;

/*
High level API. The API takes a set of operations and variables and performs each
reduction operation on its corresponding variable. The reductions of the arguments
are completely independent of each other (i.e., the val1-op1 combination has no
impact on val2-op2).

Example usage:
``` cpp
float max_val;
float min_val;
reduce::block<rop::Max, rop::Min>(tb, warp, max_val, min_val);
```

TODO(cmikeh2): In theory, we might be able to do this sequentially with
device functions and rely on the assembler correctly behaving. My initial
instinct is this won't work, but if it does it would reduce implementation
cost significantly.

TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic
currently supports this (more incidentally than anything else). It is not
uncommon in something like softmax or a fused attention kernel to map multiple
reductions to a thread block, but each reduction itself is only scoped
to part of the threads (i.e. block size = 512, 128 threads per reduction).
*/
template <ROpType Op, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val);

template <ROpType Op1, ROpType Op2, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2);

template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3);

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3,
                       float& val4);

/*
The partitioned block is a special case of the above in which the warps of a threadblock
are partitioned into separate independent reductions. For example, I might have an 8 warp
thread block in which each pair of warps is processing an independent piece of data. I would
then reduce that data with something like the following (assuming 32-thread warps, so 64
threads per partition):
``` cpp
float max_val;
reduce::partitioned_block<rop::Max, 64>(tb, warp, max_val);
```
After which, each pair of warps would have coherent data with each other. Note, this API
will not provide correct results if the number of warps per partition is not a power of 2.
*/
template <ROpType Op, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val);

template <ROpType Op1, ROpType Op2, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2);

template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3);

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3,
                                   float& val4);

/*
Single element reduction primitives. Used inside serial collection
loops.

Example usage:
using rop = reduce::ROpType;
float min = init<rop::Min>();
for (int i = 0; i < 4; i++) {
    min = reduce::element<rop::Min>(min, data[i]);
}
*/
template <ROpType Op, typename T>
DS_D_INLINE T element(const T lhs, const T rhs);

template <ROpType OType, typename T = float>
DS_D_INLINE T init();
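
/*
Illustrative only (not part of this header's API): one possible way a kernel obtains the `tb`
and `warp` arguments expected above is via cooperative groups. The kernel name, launch
geometry, and input layout here are hypothetical.
``` cpp
__global__ void example_kernel(const float* vals)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    float max_val = vals[tb.thread_rank()];
    reduce::block<reduce::ROpType::Max>(tb, warp, max_val);
    // Every thread in the block now holds the block-wide maximum in `max_val`.
}
```
*/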

/********************** Internal reduction APIs **********************/

/*
Single element "reductions". TODO(cmikeh2): this sort of "op" concept
should be refactored into its own implementation at some point. This interface
may be easily expanded for new types/operations, but the typical reductions
we need are covered with min/max/add on float.

NOTE: there is no mean reduction because that relies on knowledge of how
many values were already reduced into each scalar. Implementing this on top
of reduce should be straightforward (can just wrap the sum reduction) and
would be a good extension of the header (a sketch of one possible wrapper
appears after the element implementations below).
*/

DS_D_INLINE int _warp_rank()
{
    const int thread_rank =
        threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;
    return thread_rank / hw_warp_size;
}

/* Float element reduce implementations */
template <>
DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE float element<ROpType::Max>(const float lhs, const float rhs)
{
    return fmaxf(lhs, rhs);
}

template <>
DS_D_INLINE float element<ROpType::Min>(const float lhs, const float rhs)
{
    return fminf(lhs, rhs);
}

/* __half element reduce implementation */
template <>
DS_D_INLINE __half element<ROpType::Add>(const __half lhs, const __half rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE __half element<ROpType::Max>(const __half lhs, const __half rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmax(lhs, rhs);
#else
    return (lhs > rhs) ? lhs : rhs;
#endif
}

template <>
DS_D_INLINE __half element<ROpType::Min>(const __half lhs, const __half rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmin(lhs, rhs);
#else
    return (lhs < rhs) ? lhs : rhs;
#endif
}

/* __half2 element reduce implementation */
template <>
DS_D_INLINE __half2 element<ROpType::Add>(const __half2 lhs, const __half2 rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE __half2 element<ROpType::Max>(const __half2 lhs, const __half2 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmax2(lhs, rhs);
#else
    __half2 ret_val;
    ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}

template <>
DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmin2(lhs, rhs);
#else
    __half2 ret_val;
    ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}
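
/*
Sketch only, not part of this header: the mean reduction mentioned in the NOTE above could be
layered on the Add reduction if the caller knows how many values contributed to the block's sum.
The helper name and the `num_values` parameter are hypothetical.
``` cpp
template <int warp_bound = max_warps>
DS_D_INLINE void block_mean(cg::thread_block& tb,
                            cg::thread_block_tile<hw_warp_size>& warp,
                            float& val,
                            int num_values)
{
    // Sum across the block, then divide by the total number of contributing values.
    block<ROpType::Add, warp_bound>(tb, warp, val);
    val /= num_values;
}
```
*/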

/* Reduction initialization primitives */
template <>
DS_D_INLINE float init<ROpType::Add>()
{
    return 0.0f;
}

template <>
DS_D_INLINE float init<ROpType::Min>()
{
    // Positive infinity
    return INFINITY;
}

template <>
DS_D_INLINE float init<ROpType::Max>()
{
    // Negative infinity
    return -INFINITY;
}

template <>
DS_D_INLINE __half init<ROpType::Add>()
{
    constexpr __half_raw zero = {0x0000};
    return __half(zero);
}

template <>
DS_D_INLINE __half init<ROpType::Min>()
{
    constexpr __half_raw inf = {0x7C00};
    return __half(inf);
}

template <>
DS_D_INLINE __half init<ROpType::Max>()
{
    constexpr __half_raw neg_inf = {0xFC00};
    return __half(neg_inf);
}

template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{
#ifdef __HIP_PLATFORM_HCC__
    return __half2{_Float16_2{0x0000, 0x0000}};
#else
    constexpr __half2_raw zero = {0x0000, 0x0000};
    return __half2(zero);
#endif
}

template <>
DS_D_INLINE __half2 init<ROpType::Min>()
{
#ifdef __HIP_PLATFORM_HCC__
    return __half2{_Float16_2{0x7C00, 0x7C00}};
#else
    constexpr __half2_raw inf = {0x7C00, 0x7C00};
    return __half2(inf);
#endif
}

template <>
DS_D_INLINE __half2 init<ROpType::Max>()
{
#ifdef __HIP_PLATFORM_HCC__
    return __half2{_Float16_2{0xFC00, 0xFC00}};
#else
    constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
    return __half2(neg_inf);
#endif
}

template <ROpType Op, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op, T>();
}

template <ROpType Op1, ROpType Op2, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
}

template <ROpType Op1, ROpType Op2, ROpType Op3, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
    data[2] = init<Op3, T>();
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
    data[2] = init<Op3, T>();
    data[3] = init<Op4, T>();
}

/*
Warp reduction primitives

`reduce_width` is an unsafe template parameter in the sense that when using
`reduce_width` < hw_warp_size the warp is partitioned into
`hw_warp_size` / `reduce_width` groups of partial sums.

If someone can figure out how to use variadic templates in a reasonable way here
(fold expressions are C++17 only and I don't think they help, and recursion feels
like huge overkill that harms readability) that would be wonderful.
*/

template <ROpType Op, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op>(data[0], warp.shfl_xor(data[0], i));
    }
}

template <ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
    }
}

template <ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
    }
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int i = 1; i < reduce_width; i *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
        data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], i));
    }
}
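
/*
Illustrative only: for `reduce_width` = 4 and ROpType::Add, the unrolled loop above expands to
``` cpp
data[0] = element<ROpType::Add>(data[0], warp.shfl_xor(data[0], 1));
data[0] = element<ROpType::Add>(data[0], warp.shfl_xor(data[0], 2));
```
i.e. a butterfly exchange in which lane l combines with lanes l^1 and l^2, so every lane in each
4-lane group ends up holding that group's sum. This is the partitioned (partial-sum) behavior
described above when `reduce_width` < hw_warp_size.
*/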

/*
Implementation for primary block reduction that serves both `block` and
`partitioned_block`.

`total_warps` refers to the reduction width of the reduction, not the number of
warps in the block (which may exceed that if the block is partitioned or if we
do a conservative bound at compile time).
*/
template <int total_warps, ROpType... Ops>
DS_D_INLINE void _block(cg::thread_block& tb,
                        cg::thread_block_tile<hw_warp_size>& warp_arg,
                        float* data)
{
    constexpr int elems = sizeof...(Ops);
    // Separated for now in case this no longer is true
    constexpr int bytes = sizeof(float);
    // Unused when `partition_size == 1` or total_warps == 1
    __shared__ float reduce_buffer[max_warps * elems];

#ifdef __HIP_PLATFORM_HCC__
    const int total_threads = blockDim.x * blockDim.y * blockDim.z;
    const int running_warps = total_threads / hw_warp_size;
#else
    const int running_warps = warp_arg.meta_group_size();
#endif

    // Always perform warp-scope reduction
    _warp<Ops...>(warp_arg, data);

    // If total_warps == 1 let's skip the runtime check
    if (total_warps != 1) {
        if (warp_arg.thread_rank() == 0) {
#pragma unroll
            for (int i = 0; i < elems; i++) {
                mem_access::store_shared<bytes>(reduce_buffer + elems * _warp_rank() + i,
                                                data + i);
            }
        }

        // Synchronization inside block-uniform conditional is safe
        tb.sync();

        if (_warp_rank() == 0) {
            if (warp_arg.thread_rank() < running_warps) {
#pragma unroll
                for (int i = 0; i < elems; i++) {
                    mem_access::load_shared<bytes>(
                        data + i, reduce_buffer + elems * warp_arg.thread_rank() + i);
                }
            } else {
                init<Ops...>(data);
            }

            _warp<Ops..., total_warps>(warp_arg, data);

#pragma unroll
            for (int i = 0; i < elems; i++) {
                mem_access::store_shared<bytes>(reduce_buffer + elems * warp_arg.thread_rank() + i,
                                                data + i);
            }
        }

        // Synchronization inside block-uniform conditional is safe
        tb.sync();

#pragma unroll
        for (int i = 0; i < elems; i++) {
            mem_access::load_shared<bytes>(data + i, reduce_buffer + _warp_rank() * elems + i);
        }
    }
}

/*
Main API implementations. For the most part, they just convert the individual
variables into arrays, which makes it easier to handle them with a single
implementation.

In theory, we could expose the `_block` implementation as another option, but the
nature of using a pointer is a little less safe and this allows us to obfuscate
the details of the partitioned implementation.
*/
template <ROpType Op, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
{
    _block<warp_bound, Op>(tb, warp, &val);
}

template <ROpType Op1, ROpType Op2, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2)
{
    float data[2] = {val1, val2};
    _block<warp_bound, Op1, Op2>(tb, warp, data);
    val1 = data[0];
    val2 = data[1];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3)
{
    float data[3] = {val1, val2, val3};
    _block<warp_bound, Op1, Op2, Op3>(tb, warp, data);
    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3,
                       float& val4)
{
    float data[4] = {val1, val2, val3, val4};
    _block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
    val4 = data[3];
}
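
/*
Illustrative only (not part of the API): the multi-value overloads above reduce several
independent statistics while paying for only one pair of block synchronizations. The
`local_sum` and `local_max` names are hypothetical per-thread partials computed earlier
in the calling kernel.
``` cpp
using rop = reduce::ROpType;
float partial_sum = local_sum;
float partial_max = local_max;
reduce::block<rop::Add, rop::Max>(tb, warp, partial_sum, partial_max);
// Every thread now holds the block-wide sum in partial_sum and the block-wide max in partial_max.
```
*/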

/*
Note: for the partitioned blocks, the implementation does not support partitions whose warp
count is not a power of 2, in order to shorten the block-scale reduction length.
*/
template <ROpType Op, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val)
{
    if (num_threads <= hw_warp_size) {
        _warp<Op, num_threads>(warp, &val);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        _block<num_warps, Op>(tb, warp, &val);
    }
}

template <ROpType Op1, ROpType Op2, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2)
{
    float data[2] = {val1, val2};

    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        _block<num_warps, Op1, Op2>(tb, warp, data);
    }

    val1 = data[0];
    val2 = data[1];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3)
{
    float data[3] = {val1, val2, val3};

    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, Op3, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        _block<num_warps, Op1, Op2, Op3>(tb, warp, data);
    }

    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
}

template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3,
                                   float& val4)
{
    float data[4] = {val1, val2, val3, val4};

    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
    } else {
        constexpr int num_warps = num_threads / hw_warp_size;
        _block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
    }

    val1 = data[0];
    val2 = data[1];
    val3 = data[2];
    val4 = data[3];
}

}  // namespace reduce
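
/*
Illustrative only (not part of the API): a hypothetical use of the partitioned reduction with a
256-thread (8-warp) block, assuming 32-thread warps, where each pair of warps owns an independent
row of data. `local_row_max` is a hypothetical per-thread partial for that row.
``` cpp
float row_max = local_row_max;
reduce::partitioned_block<reduce::ROpType::Max, 64>(tb, warp, row_max);
// Threads within the same pair of warps now agree on their row's maximum;
// the other pairs hold the maxima of their own rows.
```
*/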