deepspeed_py_copy.cpp

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/

#include "deepspeed_py_copy.h"
#include <omp.h>

// Round `size` down to the nearest multiple of `step` (step must be a power of two).
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#if defined(__AVX512__) or defined(__AVX256__)
// Wrapper for a single SIMD register: 512-bit when AVX-512 is available, 256-bit otherwise.
union AVX_Data {
#if defined(__AVX512__)
    __m512 data;
#else
    __m256 data;
#endif
};
#endif
// Copy `param_size` floats from `src` to `dest`, one SIMD vector (SIMD_WIDTH floats)
// per iteration. Any elements beyond the last full vector are copied scalar-wise.
static void helper_memcpy_1(float* dest, float* src, size_t param_size)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    rounded_size = ROUND_DOWN(param_size, SIMD_WIDTH);

    // Process the vectorizable region in TILE-sized chunks, parallelized across threads.
    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH) {
            AVX_Data src_4;
            src_4.data = SIMD_LOAD(src + i);
            SIMD_STORE(dest + i, src_4.data);
        }
    }
#endif

    // Scalar tail: copy whatever did not fit a full SIMD vector.
    if (param_size > rounded_size) {
#pragma omp parallel for
        for (size_t k = rounded_size; k < param_size; k++) { dest[k] = src[k]; }
    }
}
// Copy with 4-way unrolled SIMD: each iteration moves 4 * SIMD_WIDTH floats.
// The remainder is handed off to helper_memcpy_1.
static void helper_memcpy_4(float* dest, float* src, size_t param_size)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2));

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
#pragma omp parallel for
        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
            AVX_Data src_4[4];
            src_4[0].data = SIMD_LOAD(src + i);
            src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH);
            src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1));
            src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3);

            SIMD_STORE(dest + i, src_4[0].data);
            SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data);
            SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data);
        }
    }
#endif

    if (param_size > rounded_size)
        helper_memcpy_1((dest + rounded_size), (src + rounded_size), (param_size - rounded_size));
}
// Copy with 8-way unrolled SIMD: each iteration moves 8 * SIMD_WIDTH floats.
// The remainder is handed off to helper_memcpy_4.
static void helper_memcpy_8(float* dest, float* src, size_t param_size)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    // Round down to a multiple of the 8-vector stride used by the inner loop,
    // so the unrolled loads/stores never run past the rounded region.
    rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 3));

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
#pragma omp parallel for
        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
            AVX_Data src_4[8];
            src_4[0].data = SIMD_LOAD(src + i);
            src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH);
            src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1));
            src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3);
            src_4[4].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 2));
            src_4[5].data = SIMD_LOAD(src + i + SIMD_WIDTH * 5);
            src_4[6].data = SIMD_LOAD(src + i + SIMD_WIDTH * 6);
            src_4[7].data = SIMD_LOAD(src + i + SIMD_WIDTH * 7);

            SIMD_STORE(dest + i, src_4[0].data);
            SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data);
            SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data);
            SIMD_STORE(dest + i + (SIMD_WIDTH << 2), src_4[4].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 5, src_4[5].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 6, src_4[6].data);
            SIMD_STORE(dest + i + SIMD_WIDTH * 7, src_4[7].data);
        }
    }
#endif

    if (param_size > rounded_size)
        helper_memcpy_4((dest + rounded_size), (src + rounded_size), (param_size - rounded_size));
}
// Entry point exposed to Python: copy the contents of `src` into `dest`.
// Both tensors are assumed to hold float32 data laid out contiguously.
int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src)
{
    auto dest_c = dest.contiguous();
    auto src_c = src.contiguous();

    float* dest_ptr = (float*)dest_c.data_ptr();
    float* src_ptr = (float*)src_c.data_ptr();

    // size(0) assumes a flat (1-D) tensor, so the length along dim 0 is the total element count.
    helper_memcpy_8(dest_ptr, src_ptr, dest_c.size(0));

    return 0;
}
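
/*
Usage sketch: an entry point like this is typically exported to Python through a
pybind11 / torch-extension binding built alongside this file. The module name,
the exported function name, and the binding shown below are illustrative
assumptions only; the actual binding may live in a separate translation unit.

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    {
        m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "SIMD/OpenMP tensor copy (C++)");
    }

From Python, the built extension would then be used roughly as:

    dest = torch.empty_like(src)        # flat float32 tensor, same shape as src
    ds_copy_op.deepspeed_memcpy(dest, src)
*/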