pt_binding.cpp

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>
#include <vector>
#include "custom_cuda_layers.h"
torch::Tensor token_sort_(torch::Tensor& unsorted_token_ids, int64_t original_tokens)
{
    const int layers = unsorted_token_ids.size(0);
    const int batch_size = unsorted_token_ids.size(1);
    const int reserved_tokens = unsorted_token_ids.size(2);

    launch_token_sort(unsorted_token_ids.data_ptr<int32_t>(),
                      layers,
                      batch_size,
                      reserved_tokens,
                      original_tokens,
                      c10::cuda::getCurrentCUDAStream());

    return unsorted_token_ids;
}
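
// token_gather: gathers the activations of the retained tokens selected by
// sorted_indices ([N, retained]) into a fresh tensor, returned as
// [N, retained, C] when batch_first and [retained, N, C] otherwise. Inputs
// that are not float32 are assumed by the dispatch below to be half precision.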
torch::Tensor token_gather(torch::Tensor& activations,
                           torch::Tensor& sorted_indices,
                           bool batch_first)
{
    // Activations may be in either [N, S, C] or [S, N, C] while sorted_indices is
    // always in [N, retained]

    /*
    TORCH_CHECK(sorted_indices.size(0) == activations.size(0) ||
                    sorted_indices.size(0) == activations.size(1),
                "Unable to match the batch size of the sorted indices to the activation shape.");
    TORCH_CHECK(activations.size(2) % 8 == 0,
                "Channels must be divisible by 8 to align with vectorized loads.");
    */

    // bool batch_first = sorted_indices.size(0) == activations.size(0);
    const int64_t dim_0 = (batch_first) ? sorted_indices.size(0) : sorted_indices.size(1);
    const int64_t dim_1 = (batch_first) ? sorted_indices.size(1) : sorted_indices.size(0);
    const int64_t dim_2 = activations.size(2);

    auto output = torch::empty({dim_0, dim_1, dim_2}, activations.options());

    const int batch_size = sorted_indices.size(0);
    const int channels = dim_2;
    const int retained_tokens = sorted_indices.size(1);
    const int read_batch_stride = (batch_first) ? activations.stride(0) : activations.stride(1);
    const int read_seq_stride = (batch_first) ? activations.stride(1) : activations.stride(0);
    const int write_batch_stride = (batch_first) ? output.stride(0) : output.stride(1);
    const int write_seq_stride = (batch_first) ? output.stride(1) : output.stride(0);

    if (activations.options().dtype() == torch::kFloat) {
        launch_gather_tokens((float*)output.data_ptr(),
                             (float*)activations.data_ptr(),
                             (int32_t*)sorted_indices.data_ptr(),
                             batch_size,
                             retained_tokens,
                             channels,
                             read_batch_stride,
                             read_seq_stride,
                             write_batch_stride,
                             write_seq_stride,
                             c10::cuda::getCurrentCUDAStream());
    } else {
        launch_gather_tokens((__half*)output.data_ptr(),
                             (__half*)activations.data_ptr(),
                             (int32_t*)sorted_indices.data_ptr(),
                             batch_size,
                             retained_tokens,
                             channels,
                             read_batch_stride,
                             read_seq_stride,
                             write_batch_stride,
                             write_seq_stride,
                             c10::cuda::getCurrentCUDAStream());
    }

    return output;
}
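
// token_scatter_: the in-place counterpart of token_gather; writes the rows of
// layer_activations back into all_activations at the token positions named by
// sorted_indices and returns all_activations.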
torch::Tensor token_scatter_(torch::Tensor& all_activations,
                             torch::Tensor& layer_activations,
                             torch::Tensor& sorted_indices,
                             bool batch_first)
{
    // Activations may be in either [N, S, C] or [S, N, C] while sorted_indices is
    // always in [N, retained]

    /*
    TORCH_CHECK(sorted_indices.size(0) == all_activations.size(0) ||
                    sorted_indices.size(0) == all_activations.size(1),
                "Unable to match the batch size of the sorted indices to the activation shape.");
    TORCH_CHECK(all_activations.size(2) % 8 == 0,
                "Channels must be divisible by 8 to align with vectorized loads.");
    */

    // bool batch_first = sorted_indices.size(0) == all_activations.size(0);
    const int batch_size = sorted_indices.size(0);
    const int channels = all_activations.size(2);
    const int retained_tokens = sorted_indices.size(1);
    const int read_batch_stride = (batch_first) ? layer_activations.stride(0)
                                                : layer_activations.stride(1);
    const int read_seq_stride = (batch_first) ? layer_activations.stride(1)
                                              : layer_activations.stride(0);
    const int write_batch_stride = (batch_first) ? all_activations.stride(0)
                                                 : all_activations.stride(1);
    const int write_seq_stride = (batch_first) ? all_activations.stride(1)
                                               : all_activations.stride(0);

    if (all_activations.options().dtype() == torch::kFloat) {
        launch_scatter_tokens((float*)all_activations.data_ptr(),
                              (float*)layer_activations.data_ptr(),
                              (int32_t*)sorted_indices.data_ptr(),
                              batch_size,
                              retained_tokens,
                              channels,
                              read_batch_stride,
                              read_seq_stride,
                              write_batch_stride,
                              write_seq_stride,
                              c10::cuda::getCurrentCUDAStream());
    } else {
        launch_scatter_tokens((__half*)all_activations.data_ptr(),
                              (__half*)layer_activations.data_ptr(),
                              (int32_t*)sorted_indices.data_ptr(),
                              batch_size,
                              retained_tokens,
                              channels,
                              read_batch_stride,
                              read_seq_stride,
                              write_batch_stride,
                              write_seq_stride,
                              c10::cuda::getCurrentCUDAStream());
    }

    return all_activations;
}
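
// mask_gather_bert: slices a dense [N, 1, S, S] attention mask down to the
// retained tokens of every layer, producing a
// [layers, N, 1, retained, retained] mask selected via sorted_indices.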
torch::Tensor mask_gather_bert(torch::Tensor& dense_mask, torch::Tensor& sorted_indices)
{
    // TORCH_CHECK(dense_mask.dim() == 4)

    const int batch_size = dense_mask.size(0);
    const int layers = sorted_indices.size(0);

    /*
    TORCH_CHECK(layers * batch_size == sorted_indices.size(0),
                "Mismatch between the indices and the mask");
    */

    const int orig_seq_len = dense_mask.size(3);
    const int truncated_seq_len = sorted_indices.size(2);

    auto output = torch::empty({layers, batch_size, 1, truncated_seq_len, truncated_seq_len},
                               dense_mask.options());

    if (dense_mask.options().dtype() == torch::kFloat) {
        launch_slice_bert_mask((float*)output.data_ptr(),
                               (const float*)dense_mask.data_ptr(),
                               (const int32_t*)sorted_indices.data_ptr(),
                               layers,
                               batch_size,
                               truncated_seq_len,
                               orig_seq_len,
                               c10::cuda::getCurrentCUDAStream());
    } else {
        launch_slice_bert_mask((__half*)output.data_ptr(),
                               (const __half*)dense_mask.data_ptr(),
                               (const int32_t*)sorted_indices.data_ptr(),
                               layers,
                               batch_size,
                               truncated_seq_len,
                               orig_seq_len,
                               c10::cuda::getCurrentCUDAStream());
    }

    return output;
}
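
// mask_gather_gpt: slices a dense [N, 1, S, S] causal mask down to
// [N, 1, truncated_seq_len, truncated_seq_len]. No index tensor is taken, so
// the kernel presumably keeps a contiguous leading block of positions; see
// launch_slice_gpt_mask for the exact behavior.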
torch::Tensor mask_gather_gpt(torch::Tensor dense_mask, int truncated_seq_len)
{
    // TORCH_CHECK(dense_mask.dim() == 4)

    const int batch_size = dense_mask.size(0);
    const int orig_seq_len = dense_mask.size(3);

    auto output =
        torch::empty({batch_size, 1, truncated_seq_len, truncated_seq_len}, dense_mask.options());

    if (dense_mask.options().dtype() == torch::kFloat) {
        launch_slice_gpt_mask((float*)output.data_ptr(),
                              (const float*)dense_mask.data_ptr(),
                              batch_size,
                              truncated_seq_len,
                              orig_seq_len,
                              c10::cuda::getCurrentCUDAStream());
    } else {
        launch_slice_gpt_mask((__half*)output.data_ptr(),
                              (const __half*)dense_mask.data_ptr(),
                              batch_size,
                              truncated_seq_len,
                              orig_seq_len,
                              c10::cuda::getCurrentCUDAStream());
    }

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("token_sort_", &token_sort_, "Comparison free sorting algorithm (CUDA)");
    m.def("token_gather", &token_gather, "Parallel gather of tokens (CUDA)");
    m.def("token_scatter_", &token_scatter_, "Parallel scatter of tokens (CUDA)");
    m.def("mask_gather_bert", &mask_gather_bert, "Token-based mask gather for BERT masking (CUDA)");
    m.def("mask_gather_gpt", &mask_gather_gpt, "Token-based mask gather for GPT masking (CUDA)");
}
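
// Rough Python-side usage sketch (illustrative only: the builder/loader shown
// below is an assumption and is not defined in this file):
//
//   # ops = deepspeed.ops.op_builder.RandomLTDBuilder().load()  # hypothetical loader
//   # ops.token_sort_(token_ids, orig_seq_len)                  # [L, N, reserved], sorted in place
//   # acts_l = ops.token_gather(acts, token_ids[l], True)       # [N, retained, C]
//   # ops.token_scatter_(acts, acts_l, token_ids[l], True)      # write back in place
//   # mask = ops.mask_gather_gpt(dense_mask, retained)          # [N, 1, retained, retained]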