multi_tensor_apply.cuh

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"

#include <assert.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// TODO: Kernel arg size limit may be <4KB for some other cards (i.e., Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template <int n>
struct TensorListMetadata {
    void* addresses[n][depth_to_max_tensors[n - 1]];
    int sizes[depth_to_max_tensors[n - 1]];
    unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
    int block_to_chunk[depth_to_max_blocks[n - 1]];  // I fear this needs to be a full int.
    int start_tensor_this_launch;
};
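
// Inside the kernel, the user-supplied functor typically decodes this struct as sketched
// below (an illustrative sketch only; `tl`, `chunk_size`, the depth of 2, and the float
// element type are assumptions, not definitions made by this header):
//
//     int tensor_loc = tl.block_to_tensor[blockIdx.x];  // which tensor this block works on
//     int chunk_idx = tl.block_to_chunk[blockIdx.x];    // which chunk of that tensor
//     int n = tl.sizes[tensor_loc] - chunk_idx * chunk_size;
//     n = min(n, chunk_size);                           // the last chunk may be short
//     float* in = (float*)tl.addresses[0][tensor_loc] + chunk_idx * chunk_size;
//     float* out = (float*)tl.addresses[1][tensor_loc] + chunk_idx * chunk_size;
//     for (int i = threadIdx.x; i < n; i += blockDim.x) out[i] = in[i];  // e.g. a copy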

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
                                          volatile int* noop_flag,
                                          T tl,
                                          U callable,
                                          ArgTypes... args)
{
    // Hand the chunk information to the user-supplied functor to process however it likes.
    callable(chunk_size, noop_flag, tl, args...);
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
                        int chunk_size,
                        const at::Tensor& noop_flag,
                        const std::vector<std::vector<at::Tensor>>& tensor_lists,
                        T callable,
                        ArgTypes... args)
{
    TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
    int len0 = tensor_lists[0].size();
    TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
    auto ref_device = tensor_lists[0][0].device();
    TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
    for (int l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices
    {
        TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
        for (int t = 0; t < tensor_lists[l].size(); t++) {
            // TODO: Print which tensor fails.
            bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
            contiguous_memory =
                (contiguous_memory ||
                 tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
            TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
            TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
                        "A tensor was not on the same device as the first tensor");
            TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
        }
    }

    int ntensors = tensor_lists[0].size();

    TensorListMetadata<depth> tl;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
    auto stream = at::cuda::getCurrentCUDAStream();

    tl.start_tensor_this_launch = 0;
    int loc_block_info = 0;
    int loc_tensor_info = 0;
    // Pack (tensor, chunk) work items into the metadata struct, launching a kernel
    // each time it fills up.
    for (int t = 0; t < ntensors; t++) {
        tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
        for (int d = 0; d < depth; d++)
            tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
        loc_tensor_info++;

        int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

        for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
            // std::cout << chunks_this_tensor << std::endl;
            tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
            tl.block_to_chunk[loc_block_info] = chunk;
            loc_block_info++;

            // Launch when the metadata struct cannot accept the next work item (tensor or
            // block slots exhausted), or when the final chunk of the final tensor is recorded.
            bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
                                 chunk == chunks_this_tensor - 1);
            bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
            bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
            if (tensors_full || blocks_full || last_chunk) {
                // using accscalar_t = acc_type<scalar_t, true>;
                multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
                    chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);

                AT_CUDA_CHECK(cudaGetLastError());

                // Reset. The control flow possibilities here make my brain hurt.
                loc_block_info = 0;
                if (chunk == chunks_this_tensor - 1) {
                    // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 <<
                    // std::endl;
                    loc_tensor_info = 0;
                    tl.start_tensor_this_launch = t + 1;
                } else {
                    // The current tensor still has chunks left; carry its entry over to
                    // slot 0 so the next launch continues where this one stopped.
                    // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 <<
                    // std::endl;
                    tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
                    for (int d = 0; d < depth; d++)
                        tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
                    loc_tensor_info = 1;
                    tl.start_tensor_this_launch = t;
                }
            }
        }
    }
}
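
// Example host-side call (a minimal sketch; `ScaleFunctor`, `srcs`, `dsts`, and `noop_flag`
// are hypothetical names, not defined by this header):
//
//     multi_tensor_apply<2>(/*block_size=*/512,
//                           /*chunk_size=*/2048 * 32,
//                           noop_flag,       // int32 CUDA tensor the functor may check or set
//                           {srcs, dsts},    // depth-2 lists: addresses[0] = src, addresses[1] = dst
//                           ScaleFunctor(),  // __device__ callable taking
//                                            // (int, volatile int*, TensorListMetadata<2>&, float)
//                           0.5f);           // trailing args are forwarded to the functor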