utils.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. /*
  5. DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
  7. */
  8. #include <torch/extension.h>
  9. #include <string>
  10. #include <tuple>
  11. #include <vector>
  12. #ifdef _OPENMP
  13. #include <omp.h>
  14. #endif
  15. typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
  16. void segment_blocks(torch::Tensor layout,
  17. torch::Tensor idx,
  18. torch::Tensor scratch,
  19. int max_width,
  20. ret_t& ret)
  21. {
  22. size_t H = layout.size(0);
  23. size_t M = layout.size(1);
  24. size_t N = layout.size(2);
  25. torch::Tensor tmp = torch::zeros_like(layout);
  26. auto _tmp = tmp.accessor<int, 3>();
  27. auto _layout = layout.accessor<int, 3>();
  28. auto _idx = idx.accessor<int, 3>();
  29. auto _scratch = scratch.accessor<int, 3>();
  30. std::vector<int> current(H, 0);
  31. #ifdef _OPENMP
  32. #pragma omp parallel for
  33. #endif
  34. for (size_t h = 0; h < H; h++) {
  35. // surrounding indices
  36. std::vector<int> ii_left(max_width, -1);
  37. std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
  38. for (size_t m = 0; m < M; m++) {
  39. for (size_t n = 0; n < N; n++) {
  40. int v = _layout[h][m][n];
  41. if (v == 0) continue;
  42. int n_left = ii_left[max_width - 1];
  43. int m_top = ii_top[max_width - 1][n];
  44. int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
  45. int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
  46. int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
  47. int width = std::min(left, std::min(top, topleft)) + 1;
  48. // reset width if blocks cannot be
  49. // packed together (i.e., there's a 1 "in the middle")
  50. for (int nn = n_left + 1; nn < n; nn++)
  51. if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
  52. _tmp[h][m][n] = width;
  53. // update n_left ring buffer
  54. for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
  55. ii_left[max_width - 1] = n;
  56. // update ii_top ring buffer
  57. for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
  58. ii_top[max_width - 1][n] = m;
  59. // block is too small -- skip
  60. if (width != max_width) continue;
  61. // retained blocks are set to zeros
  62. for (size_t km = 0; km < max_width; km++)
  63. for (size_t kn = 0; kn < max_width; kn++) {
  64. int mm = ii_top[km][n];
  65. int nn = ii_left[kn];
  66. if (mm < 0 || nn < 0) continue;
  67. _layout[h][mm][nn] = 0;
  68. _tmp[h][mm][nn] = 0;
  69. _scratch[h][current[h]][0] = (int)h;
  70. _scratch[h][current[h]][1] = (int)mm;
  71. _scratch[h][current[h]][2] = (int)nn;
  72. _scratch[h][current[h]][3] = _idx[h][mm][nn];
  73. current[h]++;
  74. }
  75. }
  76. }
  77. }
  78. std::vector<torch::Tensor> to_cat;
  79. for (size_t h = 0; h < H; h++)
  80. if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
  81. if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
  82. }
  83. ret_t sdd_segment(torch::Tensor layout, int start_width)
  84. {
  85. ret_t ret;
  86. // block index
  87. torch::Tensor idx = torch::zeros_like(layout);
  88. int current = 0;
  89. int64_t H = layout.size(0);
  90. int64_t M = layout.size(1);
  91. int64_t N = layout.size(2);
  92. auto _layout = layout.accessor<int, 3>();
  93. auto _idx = idx.accessor<int, 3>();
  94. for (int64_t h = 0; h < H; h++)
  95. for (int64_t m = 0; m < M; m++)
  96. for (int64_t n = 0; n < N; n++) {
  97. if (_layout[h][m][n] == 0) continue;
  98. _idx[h][m][n] = current++;
  99. }
  100. // scratch memory
  101. torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
  102. for (int max_width = start_width; max_width > 0; max_width /= 2)
  103. segment_blocks(layout, idx, scratch, max_width, ret);
  104. return ret;
  105. }
// Python bindings: expose sdd_segment under the extension module name
// supplied by the PyTorch build (TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}