deepspeed_py_copy.cpp

/*
Copyright 2020 The Microsoft DeepSpeed Team
Licensed under the MIT license.

Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/

#include "deepspeed_py_copy.h"
#include <omp.h>

#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
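
// ROUND_DOWN truncates `size` to the nearest lower multiple of `step`; the bit trick
// assumes `step` is a power of two, which holds for the SIMD widths used below.
// For example, assuming an AVX-512 width of 16 floats (see the sketch of the header
// macros further down): ROUND_DOWN(100, 16) == 96 and ROUND_DOWN(100, 64) == 64.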

#if defined(__AVX512__) or defined(__AVX256__)
union AVX_Data {
#if defined(__AVX512__)
    __m512 data;
#else
    __m256 data;
#endif
};
#endif
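
// The SIMD_WIDTH, SIMD_LOAD, SIMD_STORE and TILE macros used below come from
// deepspeed_py_copy.h. As a rough sketch of what such definitions typically look
// like (illustrative assumptions only, not necessarily the header's exact contents):
//
//   #if defined(__AVX512__)
//   #define SIMD_WIDTH 16                          // 16 floats per 512-bit register
//   #define SIMD_LOAD(x) _mm512_loadu_ps(x)
//   #define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
//   #elif defined(__AVX256__)
//   #define SIMD_WIDTH 8                           // 8 floats per 256-bit register
//   #define SIMD_LOAD(x) _mm256_loadu_ps(x)
//   #define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
//   #endif
//   #define TILE (1024 * 1024 * 1024)              // elements per outer tile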

// Copy `param_size` floats from `src` to `dest`, one SIMD register at a time,
// falling back to a scalar loop for the non-vectorizable tail.
static void helper_memcpy_1(float* dest, float* src, size_t param_size)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    rounded_size = ROUND_DOWN(param_size, SIMD_WIDTH);

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH) {
            AVX_Data src_4;
            src_4.data = SIMD_LOAD(src + i);
            SIMD_STORE(dest + i, src_4.data);
        }
    }
#endif

    // Scalar tail: copy whatever the SIMD loop (if any) did not cover.
    if (param_size > rounded_size) {
#pragma omp parallel for
        for (size_t k = rounded_size; k < param_size; k++) { dest[k] = src[k]; }
    }
}

// Same as helper_memcpy_1, but unrolled 4x (four SIMD registers per iteration);
// the remainder is handed off to helper_memcpy_1.
static void helper_memcpy_4(float* dest, float* src, size_t param_size)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2));

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
#pragma omp parallel for
        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
            AVX_Data src_4[4];
            src_4[0].data = SIMD_LOAD(src + i);
            src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH);
            src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1));
            src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3);

            SIMD_STORE(dest + i, src_4[0].data);
            SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data);
            SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data);
        }
    }
#endif

    if (param_size > rounded_size)
        helper_memcpy_1((dest + rounded_size), (src + rounded_size), (param_size - rounded_size));
}

// Same as helper_memcpy_4, but unrolled 8x; the remainder is handed off to
// helper_memcpy_4.
static void helper_memcpy_8(float* dest, float* src, size_t param_size)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    // Round down to a multiple of the 8x-unrolled stride so the inner loop never
    // reads or writes past rounded_size.
    rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 3));

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
#pragma omp parallel for
        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
            AVX_Data src_4[8];
            src_4[0].data = SIMD_LOAD(src + i);
            src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH);
            src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1));
            src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3);
            src_4[4].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 2));
            src_4[5].data = SIMD_LOAD(src + i + SIMD_WIDTH * 5);
            src_4[6].data = SIMD_LOAD(src + i + SIMD_WIDTH * 6);
            src_4[7].data = SIMD_LOAD(src + i + SIMD_WIDTH * 7);

            SIMD_STORE(dest + i, src_4[0].data);
            SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data);
            SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data);
            SIMD_STORE(dest + i + (SIMD_WIDTH << 2), src_4[4].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 5, src_4[5].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 6, src_4[6].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 7, src_4[7].data);
        }
    }
#endif

    if (param_size > rounded_size)
        helper_memcpy_4((dest + rounded_size), (src + rounded_size), (param_size - rounded_size));
}

// Python-facing entry point: copies the contents of `src` into `dest`.
// Both tensors are made contiguous first, and dest.size(0) elements are copied,
// so callers are expected to pass flattened (1-D) float tensors of equal length.
int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src)
{
    auto dest_c = dest.contiguous();
    auto src_c = src.contiguous();

    float* dest_ptr = (float*)dest_c.data_ptr();
    float* src_ptr = (float*)src_c.data_ptr();

    helper_memcpy_8(dest_ptr, src_ptr, dest_c.size(0));

    return 0;
}
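
// Usage sketch: deepspeed_py_memcpy is exposed to Python through pybind11 as part
// of the extension build. The binding and module names below are illustrative
// assumptions, not necessarily the ones DeepSpeed actually registers:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       m.def("deepspeed_memcpy", &deepspeed_py_memcpy,
//             "SIMD/OpenMP accelerated tensor copy (C++)");
//   }
//
// From Python, the op would then be called on two equally sized, flattened float
// tensors, e.g. `module.deepspeed_memcpy(dst_tensor, src_tensor)`.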