multi_tensor_apply.cuh 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. /* Copyright 2020 The Microsoft DeepSpeed Team
  2. Copyright NVIDIA/apex
  3. This file is adapted from fused adam in NVIDIA/apex, commit a109f85
  4. */
  5. #include <ATen/ATen.h>
  6. #include <ATen/AccumulateType.h>
  7. #include <ATen/cuda/CUDAContext.h>
  8. #include <ATen/cuda/Exceptions.h>
  9. #include <c10/cuda/CUDAGuard.h>
  10. #include "compat.h"
  11. #include <assert.h>
  12. // #include <iostream>
  13. // This header is the one-stop shop for all your multi-tensor apply needs.
  // Capacity tables for TensorListMetadata, indexed by (depth - 1), where depth is the
  // number of parallel tensor lists handed to multi_tensor_apply.
  //
  // The metadata struct is passed to the kernel by value, so its size counts against the
  // kernel-argument size limit.
  // TODO: Kernel arg size limit may be <4KB for some other cards (e.g. Jetson)

  // Maximum number of tensors that one kernel launch can describe.
  constexpr int depth_to_max_tensors[5] = {
      110,  // depth 1
      64,   // depth 2
      48,   // depth 3
      36,   // depth 4
      30,   // depth 5
  };

  // Maximum number of thread blocks (one per chunk) that one kernel launch can describe.
  constexpr int depth_to_max_blocks[5] = {
      320,  // depth 1
      320,  // depth 2
      320,  // depth 3
      320,  // depth 4
      320,  // depth 5
  };
  /*
   * Per-launch description of the work handed to multi_tensor_apply_kernel.
   *
   * n is the number of parallel tensor lists (the "depth"). The struct is filled on the
   * host and passed to the kernel by value, so it must stay a plain aggregate and its
   * field order/types must not change without updating every functor that reads it.
   */
  template <int n>
  struct TensorListMetadata {
      static_assert(n >= 1 && n <= 5,
                    "TensorListMetadata: depth must be in [1, 5] (size of the capacity tables)");
      // block_to_tensor stores a tensor slot index in an unsigned char, so the largest
      // slot index (max_tensors - 1) has to fit in 8 bits.
      static_assert(depth_to_max_tensors[n - 1] <= 256,
                    "TensorListMetadata: tensor slot index does not fit in unsigned char");

      // addresses[d][i]: data pointer of the i-th tensor (slot) of list d in this launch.
      void* addresses[n][depth_to_max_tensors[n - 1]];
      // sizes[i]: element count of the i-th tensor (slot) in this launch.
      int sizes[depth_to_max_tensors[n - 1]];
      // block_to_tensor[b]: slot index of the tensor that block b works on.
      unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
      // block_to_chunk[b]: index of the chunk, within that tensor, that block b works on.
      int block_to_chunk[depth_to_max_blocks[n - 1]];  // I fear this needs to be a full int.
      // Index, in the caller's tensor list, of the tensor stored in slot 0 of this launch.
      int start_tensor_this_launch;
  };
  /*
   * Generic kernel entry point used by multi_tensor_apply.
   *
   * The kernel itself contains no logic: it forwards every argument to the user-supplied
   * functor, which decides how to interpret the launch metadata.
   *
   *   chunk_size - chunk length, passed through unchanged to the functor
   *   noop_flag  - device pointer passed through unchanged to the functor
   *   tl         - launch metadata (a TensorListMetadata when launched by multi_tensor_apply)
   *   callable   - functor invoked as callable(chunk_size, noop_flag, tl, args...)
   *   args       - extra arguments forwarded to the functor
   */
  template <typename T, typename U, typename... ArgTypes>
  __global__ void multi_tensor_apply_kernel(int chunk_size,
                                            volatile int* noop_flag,
                                            T tl,
                                            U callable,
                                            ArgTypes... args)
  {
      // Every thread makes the same call; the functor is responsible for all indexing.
      callable(chunk_size, noop_flag, tl, args...);
  }
  /*
   * Launches multi_tensor_apply_kernel for the blocks currently queued in `tl`.
   * Does nothing when no blocks are queued, so a zero-sized grid is never launched.
   */
  template <int depth, typename T, typename... ArgTypes>
  void multi_tensor_apply_launch(int num_blocks,
                                 int block_size,
                                 int chunk_size,
                                 const at::Tensor& noop_flag,
                                 const TensorListMetadata<depth>& tl,
                                 cudaStream_t stream,
                                 const T& callable,
                                 const ArgTypes&... args)
  {
      if (num_blocks <= 0) return;
      multi_tensor_apply_kernel<<<num_blocks, block_size, 0, stream>>>(
          chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
      AT_CUDA_CHECK(cudaGetLastError());
  }

  /*
   * Applies `callable` to `depth` parallel lists of tensors using as few kernel launches
   * as the metadata capacity allows.
   *
   * Every tensor is cut into ceil(numel / chunk_size) chunks and one thread block is
   * queued per chunk. Tensor pointers/sizes and the block -> (tensor, chunk) mapping are
   * packed into a TensorListMetadata that is passed to the kernel by value. A launch is
   * issued on the current CUDA stream whenever the tensor table or the block table is
   * full, and once more at the end for whatever is still queued.
   *
   *   block_size   - threads per block for every launch
   *   chunk_size   - elements per chunk; must be positive
   *   noop_flag    - tensor whose int data pointer is forwarded to the functor
   *                  (presumably an early-exit flag; confirm against the functors)
   *   tensor_lists - exactly `depth` lists of equal length; tensors must be contiguous
   *                  CUDA tensors on one device, and tensor t must have the same numel
   *                  in every list
   *   callable     - functor invoked on the device as
   *                  callable(chunk_size, noop_flag_ptr, metadata, args...)
   *   args         - extra arguments forwarded to the functor
   *
   * Zero-element tensors are allowed: they keep their metadata slot (so that
   * start_tensor_this_launch + slot still equals the tensor's index in the list) but
   * queue no blocks.
   *
   * Raises (via TORCH_CHECK) on any violated precondition and (via AT_CUDA_CHECK) when a
   * launch reports an error.
   */
  template <int depth, typename T, typename... ArgTypes>
  void multi_tensor_apply(int block_size,
                          int chunk_size,
                          const at::Tensor& noop_flag,
                          const std::vector<std::vector<at::Tensor>>& tensor_lists,
                          T callable,
                          ArgTypes... args)
  {
      // A non-positive chunk size would divide by zero / produce a bogus chunk count below.
      TORCH_CHECK(chunk_size > 0, "chunk_size must be > 0");
      TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
      const size_t len0 = tensor_lists[0].size();
      TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
      auto ref_device = tensor_lists[0][0].device();
      TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");

      // TensorListMetadata stores element counts as int.
      constexpr int64_t max_numel = 2147483647;

      for (size_t l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices
      {
          TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
          for (size_t t = 0; t < len0; t++) {
              const at::Tensor& tensor = tensor_lists[l][t];
              // TODO: Print which tensor fails.
              bool contiguous_memory = tensor.is_contiguous();
  #ifdef VERSION_GE_1_5
              contiguous_memory =
                  (contiguous_memory || tensor.is_contiguous(at::MemoryFormat::ChannelsLast));
  #endif
              TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
              TORCH_CHECK(tensor.device() == ref_device,
                          "A tensor was not on the same device as the first tensor");
              TORCH_CHECK(tensor.numel() == tensor_lists[0][t].numel(), "Size mismatch");
              TORCH_CHECK(tensor.numel() <= max_numel,
                          "A tensor has more elements than fit in a 32-bit int");
          }
      }

      const int ntensors = static_cast<int>(len0);
      constexpr int max_tensors = depth_to_max_tensors[depth - 1];
      constexpr int max_blocks = depth_to_max_blocks[depth - 1];

      TensorListMetadata<depth> tl;

      const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
      auto stream = at::cuda::getCurrentCUDAStream();

      tl.start_tensor_this_launch = 0;
      int loc_block_info = 0;   // blocks queued for the next launch
      int loc_tensor_info = 0;  // tensor slots used for the next launch

      for (int t = 0; t < ntensors; t++) {
          const int64_t numel = tensor_lists[0][t].numel();

          // Register tensor t in the next free slot.
          tl.sizes[loc_tensor_info] = static_cast<int>(numel);
          for (int d = 0; d < depth; d++)
              tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
          loc_tensor_info++;

          const int chunks_this_tensor =
              static_cast<int>((numel + chunk_size - 1) / chunk_size);

          if (chunks_this_tensor == 0) {
              // Empty tensor: it occupies a slot but queues no blocks, so the capacity
              // check inside the chunk loop never runs for it. Flush here once the
              // tensor table is full, otherwise the next tensor would be written past
              // the end of tl.sizes / tl.addresses.
              if (loc_tensor_info == max_tensors) {
                  multi_tensor_apply_launch(loc_block_info,
                                            block_size,
                                            chunk_size,
                                            noop_flag,
                                            tl,
                                            stream,
                                            callable,
                                            args...);
                  loc_block_info = 0;
                  loc_tensor_info = 0;
                  tl.start_tensor_this_launch = t + 1;
              }
              continue;
          }

          for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
              tl.block_to_tensor[loc_block_info] =
                  static_cast<unsigned char>(loc_tensor_info - 1);
              tl.block_to_chunk[loc_block_info] = chunk;
              loc_block_info++;

              const bool tensor_done = (chunk == chunks_this_tensor - 1);
              const bool tensors_full = (loc_tensor_info == max_tensors && tensor_done);
              const bool blocks_full = (loc_block_info == max_blocks);

              if (tensors_full || blocks_full) {
                  multi_tensor_apply_launch(loc_block_info,
                                            block_size,
                                            chunk_size,
                                            noop_flag,
                                            tl,
                                            stream,
                                            callable,
                                            args...);

                  // Reset. The control flow possibilities here make my brain hurt.
                  loc_block_info = 0;
                  if (tensor_done) {
                      // The current tensor is finished: the next launch starts clean
                      // with the following tensor in slot 0.
                      loc_tensor_info = 0;
                      tl.start_tensor_this_launch = t + 1;
                  } else {
                      // The current tensor still has chunks left: carry it over into
                      // slot 0 of the next launch.
                      tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
                      for (int d = 0; d < depth; d++)
                          tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
                      loc_tensor_info = 1;
                      tl.start_tensor_this_launch = t;
                  }
              }
          }
      }

      // Launch whatever is still queued. This also covers lists that end in one or more
      // empty tensors, where no "last chunk" exists to trigger the final launch.
      multi_tensor_apply_launch(
          loc_block_info, block_size, chunk_size, noop_flag, tl, stream, callable, args...);
  }