  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #include "custom_cuda_layers.h"
  5. #include "memory_access_utils.h"
  6. namespace cg = cooperative_groups;
  7. template <typename T>
  8. __global__ void slice_gpt_mask_impl(T* output_mask,
  9. const T* input_mask,
  10. int truncated_seq_len,
  11. int orig_seq_len)
  12. {
  13. const int in_batch_stride = orig_seq_len * orig_seq_len;
  14. const int out_batch_stride = truncated_seq_len * truncated_seq_len;
  15. cg::thread_block tb = cg::this_thread_block();
  16. const T* input_mask_block =
  17. input_mask + blockIdx.x * in_batch_stride + blockIdx.y * orig_seq_len;
  18. T* output_mask_block =
  19. output_mask + blockIdx.x * out_batch_stride + blockIdx.y * truncated_seq_len;
  20. for (int i = tb.thread_index().x; i < truncated_seq_len; i += blockDim.x) {
  21. output_mask_block[i] = input_mask_block[i];
  22. }
  23. }
  24. template <typename T>
  25. void launch_slice_gpt_mask(T* output_mask,
  26. const T* input_mask,
  27. int batch_size,
  28. int truncated_seq_len,
  29. int orig_seq_len,
  30. cudaStream_t stream)
  31. {
  32. const int threads = (truncated_seq_len >= 1024) ? 1024 : truncated_seq_len;
  33. dim3 block(threads);
  34. dim3 grid(batch_size, truncated_seq_len);
  35. slice_gpt_mask_impl<T>
  36. <<<grid, block, 0, stream>>>(output_mask, input_mask, truncated_seq_len, orig_seq_len);
  37. }
  38. template void launch_slice_gpt_mask<float>(float*, const float*, int, int, int, cudaStream_t);
  39. template void launch_slice_gpt_mask<__half>(__half*, const __half*, int, int, int, cudaStream_t);
/*
Gathers a truncated (truncated_seq_len x truncated_seq_len) mask per batch from
the square (orig_seq_len x orig_seq_len) input mask, selecting both the rows
and the columns listed in retained_indices, and replicates the result once per
layer in the output.

Grid mapping (see launcher): x = layer, y = batch, z = output row.
retained_indices is indexed [batch][truncated position] — values are row/column
indices into the original mask, presumably in [0, orig_seq_len) — and has no
layer dimension, so every layer receives an identical gathered mask.
*/
template <typename T>
__global__ void slice_bert_mask_impl(T* output_mask,
                                     const T* input_mask,
                                     const int32_t* retained_indices,
                                     int32_t truncated_seq_len,
                                     int32_t orig_seq_len)
{
    // Element strides of one batch's square mask, input vs. output.
    const int in_batch_stride = orig_seq_len * orig_seq_len;
    const int out_batch_stride = truncated_seq_len * truncated_seq_len;
    // gridDim.y == batch_size, so this is the size of one full layer's output.
    const int out_layer_stride = out_batch_stride * gridDim.y;
    cg::thread_block tb = cg::this_thread_block();
    // group_index(): x = layer, y = batch, z = output row.
    const int out_layer_offset = tb.group_index().x * out_layer_stride;
    const int in_batch_offset = tb.group_index().y * in_batch_stride;
    const int out_batch_offset = tb.group_index().y * out_batch_stride;
    // Original-mask row that this block's output row (z) maps to.
    const int32_t gather_row =
        retained_indices[tb.group_index().y * truncated_seq_len + tb.group_index().z];
    const int in_seq_offset = gather_row * orig_seq_len;
    const int out_seq_offset = tb.group_index().z * truncated_seq_len;
    const T* in_sequence = input_mask + in_batch_offset + in_seq_offset;
    T* out_sequence = output_mask + out_layer_offset + out_batch_offset + out_seq_offset;
    // This batch's slice of retained indices, reused to gather columns.
    const int32_t* gather_data = retained_indices + tb.group_index().y * truncated_seq_len;
    // Column gather: output element i comes from column gather_data[i] of the
    // selected input row. Writes are coalesced; reads are indirect.
    for (int i = tb.thread_index().x; i < truncated_seq_len; i += blockDim.x) {
        out_sequence[i] = in_sequence[gather_data[i]];
    }
}
/*
Since the Bert mask is not causal like GPT, we can't just generate a set of
masks for the entire model based off a single layer sample.
We map the kernel as follows:
x-dimension: layer
y-dimension: batch
z-dimension: sequence_offset
*/
  73. template <typename T>
  74. void launch_slice_bert_mask(T* output_mask,
  75. const T* input_mask,
  76. const int32_t* retained_indices,
  77. int32_t layers,
  78. int32_t batch_size,
  79. int32_t truncated_seq_len,
  80. int32_t orig_seq_len,
  81. cudaStream_t stream)
  82. {
  83. const int threads = (truncated_seq_len >= 1024) ? 1024 : truncated_seq_len;
  84. dim3 block(threads);
  85. dim3 grid(layers, batch_size, truncated_seq_len);
  86. slice_bert_mask_impl<T><<<grid, block, 0, stream>>>(
  87. output_mask, input_mask, retained_indices, truncated_seq_len, orig_seq_len);
  88. }
  89. template void launch_slice_bert_mask<float>(float*,
  90. const float*,
  91. const int32_t*,
  92. int32_t,
  93. int32_t,
  94. int32_t,
  95. int32_t,
  96. cudaStream_t);
  97. template void launch_slice_bert_mask<__half>(__half*,
  98. const __half*,
  99. const int32_t*,
  100. int32_t,
  101. int32_t,
  102. int32_t,
  103. int32_t,
  104. cudaStream_t);