// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cassert>
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

namespace td_sort {
constexpr int threads = 512;
constexpr int granularity = 16;
constexpr int mem_vals = granularity / sizeof(int32_t);
constexpr int max_buffer_size = (threads + 1) * mem_vals;

#ifdef __HIP_PLATFORM_HCC__
constexpr int warp_size = 64;
#else
constexpr int warp_size = 32;
#endif

constexpr int max_warps = threads / warp_size;
}  // namespace td_sort

// Sorts the `reserved_tokens` token ids owned by each block in place: mark which ids are present
// in a shared-memory flag buffer, run a block-wide inclusive prefix sum over the flags (warp-level
// scans followed by a scan of the per-warp totals), then scatter each id to its prefix-sum rank so
// the segment ends up in ascending order.
template <int VALS_PER_THREAD>
__global__ void scan_sort(int32_t* data, int reserved_tokens, int original_tokens)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<td_sort::warp_size> warp = cg::tiled_partition<td_sort::warp_size>(tb);

    __shared__ int32_t indices_buffer[td_sort::max_buffer_size];
    __shared__ int32_t intermediate_buffer[td_sort::max_warps];
    __shared__ int32_t sorted_indices_buffer[td_sort::max_buffer_size];

    // Zero the flag buffer for every possible token id (plus one extra slot at original_tokens).
    for (int i = tb.thread_index().x * td_sort::mem_vals; i < original_tokens + 1;
         i += tb.group_dim().x * td_sort::mem_vals) {
        uint32_t zeros[td_sort::mem_vals] = {0, 0, 0, 0};
        mem_access::store_shared<td_sort::granularity>(indices_buffer + i, zeros);
    }

    int32_t local_vals[VALS_PER_THREAD];

    // We flatten layers/batch into a single indexing dimension
    int32_t* data_block = data + tb.group_index().x * reserved_tokens;

    // The next two loops really could be fused for a more logical code layout, but don't want to
    // move the barrier forward
#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            mem_access::load_global<sizeof(int32_t)>(local_vals + i, data_block + iter_idx);
        } else {
            local_vals[i] = 0;
        }
    }

    tb.sync();

    // Mark each loaded token id as present.
#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            const int32_t one = 1;
            mem_access::store_shared<sizeof(int32_t)>(indices_buffer + local_vals[i], &one);
        }
    }

    tb.sync();

    // Serial scan over the flags this thread owns.
    int32_t local_input[td_sort::mem_vals];
    mem_access::load_shared<td_sort::granularity>(
        local_input, indices_buffer + tb.thread_index().x * td_sort::mem_vals);

    int32_t reduce_vals[td_sort::mem_vals];
    reduce_vals[0] = local_input[0];

#pragma unroll
    for (int i = 1; i < td_sort::mem_vals; i++) {
        reduce_vals[i] = local_input[i] + reduce_vals[i - 1];
    }

    int32_t step_1_val = reduce_vals[td_sort::mem_vals - 1];

    // Short span exclusive scan algorithm (less work efficient)
#pragma unroll
    for (int i = 1; i < td_sort::warp_size; i *= 2) {
        int32_t step_val = warp.shfl_up(step_1_val, i);
        step_1_val = (warp.thread_rank() < i) ? step_1_val : step_1_val + step_val;
    }

    if (warp.thread_rank() == td_sort::warp_size - 1) {
        mem_access::store_shared<sizeof(int32_t)>(intermediate_buffer + warp.meta_group_rank(),
                                                  &step_1_val);
    }

    tb.sync();

    // The first warp scans the per-warp totals to produce warp-level offsets.
    if (warp.meta_group_rank() == 0) {
        int32_t step_2_val = 0;
        if (warp.thread_rank() < td_sort::max_warps) {
            mem_access::load_shared<sizeof(int32_t)>(&step_2_val,
                                                     intermediate_buffer + warp.thread_rank());
        }

#pragma unroll
        for (int i = 1; i < td_sort::warp_size; i *= 2) {
            int32_t step_val = warp.shfl_up(step_2_val, i);
            step_2_val = (warp.thread_rank() < i) ? step_2_val : step_2_val + step_val;
        }

        if (warp.thread_rank() < td_sort::max_warps) {
            mem_access::store_shared<sizeof(int32_t)>(intermediate_buffer + warp.thread_rank(),
                                                      &step_2_val);
        }
    }

    tb.sync();

    int step_2_val = 0;
    if (warp.meta_group_rank() > 0) {
        mem_access::load_shared<sizeof(int32_t)>(&step_2_val,
                                                 intermediate_buffer + warp.meta_group_rank() - 1);
    }

    // Combine the per-thread, warp-level, and block-level partial sums into the final inclusive
    // scan values and write them back to shared memory.
    const int thread_offset = reduce_vals[td_sort::mem_vals - 1];
#pragma unroll
    for (int i = 0; i < td_sort::mem_vals; i++) {
        reduce_vals[i] += step_1_val + step_2_val - thread_offset;
    }
    mem_access::store_shared<td_sort::granularity>(
        indices_buffer + tb.thread_index().x * td_sort::mem_vals, reduce_vals);

    if (tb.thread_index().x == 0) {
        indices_buffer[original_tokens] = original_tokens - indices_buffer[original_tokens];
    }

    tb.sync();

    // Scatter each token id to its rank among the ids present in this segment.
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            if (local_vals[i] == 0) {
                int zero = 0;
                mem_access::store_shared<sizeof(int32_t)>(sorted_indices_buffer, &zero);
            } else {
                int sorted_idx;
                mem_access::load_shared<sizeof(int32_t)>(&sorted_idx,
                                                         indices_buffer + local_vals[i] - 1);
                mem_access::store_shared<sizeof(int32_t)>(sorted_indices_buffer + sorted_idx,
                                                          local_vals + i);
            }
        }
    }

    tb.sync();

    // Write the sorted segment back to global memory.
#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            int32_t store_val;
            mem_access::load_shared<sizeof(int32_t)>(&store_val, sorted_indices_buffer + iter_idx);
            mem_access::store_global<sizeof(int32_t)>(data_block + iter_idx, &store_val);
        }
    }
}

void launch_token_sort(int32_t* indices,
                       int layers,
                       int batch_size,
                       int reserved_size,
                       int original_tokens,
                       cudaStream_t stream)
{
    // Each sort is completely independent, can flatten this dimension
    dim3 grid(layers * batch_size);
    dim3 block(td_sort::threads);

    const int vals_per_thread = (reserved_size + td_sort::threads - 1) / td_sort::threads;

    if (vals_per_thread == 1) {
        scan_sort<1><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 2) {
        scan_sort<2><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 3) {
        scan_sort<3><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 4) {
        scan_sort<4><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else {
        assert(false);
    }
}
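
// Usage sketch (not part of the original file): `launch_token_sort` sorts, in place and in
// ascending order, each contiguous `reserved_size`-element segment of `indices`, one segment per
// (layer, batch) pair. The host-side call below is a hypothetical illustration; the buffer name
// `d_indices` and the surrounding allocation are assumptions, not DeepSpeed APIs.
//
//   int32_t* d_indices;  // layers * batch_size * reserved_size token ids on the device
//   cudaMalloc(&d_indices, layers * batch_size * reserved_size * sizeof(int32_t));
//   // ... fill d_indices with token ids drawn from the original sequence positions ...
//   launch_token_sort(d_indices, layers, batch_size, reserved_size, original_tokens, stream);
//
// Note that the launcher only instantiates up to 4 values per thread, so `reserved_size` must be
// at most 4 * td_sort::threads (2048) or the final `assert(false)` fires.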