# mpi.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. import cupy
  6. import time
  7. import numpy as np
  8. from mpi4py import MPI
  9. from deepspeed.runtime.compression.cupy import CupyBackend
class MpiBackend(object):
    """MPI communication backend for 1-bit compressed allreduce.

    Exchanges sign-bit-packed gradients plus per-tensor scales over
    ``MPI.COMM_WORLD``, using either CUDA-aware MPI directly on cupy
    buffers (``cuda_aware=True``) or staging through host numpy buffers
    (``cuda_aware=False``). Bit packing/unpacking is delegated to
    ``CupyBackend`` (torch2cupy/cupy2torch/compress_by_chunk).
    """

    def __init__(self, cuda_aware):
        # cuda_aware: truthy when the MPI library can send/receive GPU
        # (cupy) buffers directly; otherwise host staging paths are used.
        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()
        self.cuda_aware = cuda_aware
        self.compression_backend = CupyBackend()

    def my_igather(self, rank, size, comm, sendbuf, recbuf, root):
        """Nonblocking gather of ``sendbuf`` from all ranks into ``recbuf`` rows at ``root``.

        The root posts one ``Irecv`` per remote rank into ``recbuf[idx]`` and
        stores its own contribution by plain assignment (``recbuf[rank] =
        sendbuf`` — a reference for a Python list, an element-copy for an
        ndarray row). Non-root ranks post a single ``Isend`` to the root.

        Returns:
            list: outstanding MPI request objects; the caller must complete
            them (see ``MPI.Request.Waitall`` in the gather methods).
        """
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(comm.Irecv(recbuf[idx], source=idx))
                else:
                    recbuf[rank] = sendbuf
        else:
            req.append(comm.Isend(sendbuf, dest=root))
        return req

    def gather_cuda(self, rank, world_size, comm, cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale,
                    cupy_recvbuf_scale):
        """Phase-1 gather over CUDA-aware MPI, operating directly on cupy buffers.

        Every rank acts as root of one gather (root=idx), so each rank ends up
        with all workers' packed signs in ``cupy_recvbuf_sign`` and scales in
        ``cupy_recvbuf_scale``.
        """
        # We do in-place operations on cupy buffers so we do not return any buffers
        requests = []
        for idx in range(world_size):
            req_sign = self.my_igather(rank, world_size, comm, cupy_sign_list_packed[idx], cupy_recvbuf_sign, root=idx)
            requests += req_sign
        for idx in range(world_size):
            req_scale = self.my_igather(rank, world_size, comm, cupy_worker_scale, cupy_recvbuf_scale, root=idx)
            requests += req_scale
        # Complete every outstanding Isend/Irecv before the buffers are read.
        MPI.Request.Waitall(requests)

    def gather_host(self, rank, world_size, comm, cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale,
                    cupy_recvbuf_scale):
        """Phase-1 gather staged through host (numpy) buffers for non-CUDA-aware MPI.

        Converts inputs GPU->host, communicates on numpy buffers, then converts
        the results host->GPU and returns the *new* cupy buffers.
        """
        # In-place operations are not possible for newly created cupy arrays
        # so we need to return the new buffers
        numpy_recvbuf_sign = np.zeros([world_size, cupy_sign_list_packed[rank].size],
                                      dtype=cupy_sign_list_packed[0].dtype)
        # NOTE(review): this np.zeros allocation is dead — it is overwritten by
        # cupy.asnumpy(cupy_recvbuf_scale) a few lines below.
        numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)

        # 1. convert from cupy to numpy
        # NOTE(review): this aliases the caller's list, so the loop below
        # replaces the caller's cupy chunks with numpy arrays until they are
        # converted back in step 3.
        numpy_sign_list_packed = cupy_sign_list_packed
        for idx in range(world_size):
            numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])
        numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
        numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)
        # Ensure device->host copies have completed before MPI touches them.
        cupy.cuda.get_current_stream().synchronize()

        # 2. use numpy buffers for communication
        requests = []
        for idx in range(world_size):
            req_sign = self.my_igather(rank,
                                       world_size,
                                       comm,
                                       numpy_sign_list_packed[idx],
                                       numpy_recvbuf_sign,
                                       root=idx)
            requests += req_sign
        for idx in range(world_size):
            req_scale = self.my_igather(rank, world_size, comm, numpy_worker_scale, numpy_recvbuf_scale, root=idx)
            requests += req_scale
        MPI.Request.Waitall(requests)

        # 3. Convert back from numpy to cupy
        cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
        for idx in range(world_size):
            cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])
        cupy_worker_scale = cupy.asarray(numpy_worker_scale)
        cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
        # Ensure host->device copies have completed before returning.
        cupy.cuda.get_current_stream().synchronize()
        return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale

    def allgather_cuda(self, comm, cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale,
                       cupy_recvbuf_scale_server):
        """Phase-2 allgather of the server-side packed signs and scale, directly on cupy buffers."""
        comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
        comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)

    def allgather_host(self, comm, cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale,
                       cupy_recvbuf_scale_server):
        """Phase-2 allgather staged through host (numpy) buffers.

        Same convert -> communicate -> convert-back shape as ``gather_host``;
        returns the new cupy buffers.
        """
        # 1. Convert cupy to numpy
        # NOTE(review): both np.zeros allocations below are dead — they are
        # immediately overwritten by the cupy.asnumpy calls that follow.
        numpy_recvbuf_sign_server = np.zeros([comm.Get_size(), cupy_server_sign_packed.size],
                                             dtype=cupy_server_sign_packed.dtype)
        numpy_recvbuf_scale_server = np.zeros([comm.Get_size(), 1], dtype=cupy_server_scale.dtype)

        numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
        numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
        numpy_server_scale = cupy.asnumpy(cupy_server_scale)
        numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
        cupy.cuda.get_current_stream().synchronize()

        # 2. Communicate numpy buffers
        comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
        comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
        comm.Barrier()

        # 3. Convert numpy back to cupy
        cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
        cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
        cupy_server_scale = cupy.asarray(numpy_server_scale)
        cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
        cupy.cuda.get_current_stream().synchronize()
        return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server

    def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error, server_error, local_rank):
        """Two-phase 1-bit compressed allreduce with error feedback.

        Phase 1: each rank adds its accumulated ``worker_error`` to
        ``buffer_m``, compresses to sign bits + one L2-based scale, and
        gathers every rank's chunk. Phase 2: each rank averages its gathered
        chunk, compresses again with ``server_error`` feedback, and
        allgathers the result. Both error tensors are updated in place
        (via ``set_``) with the new compression residuals.

        Args:
            buffer_m: tensor to reduce; flattened internally and restored to
                ``original_shape`` before returning. Overwritten in place
                (``buffer_m.data.copy_``) with the decompressed result.
            worker_error: phase-1 residual buffer; its numel defines the
                padded working size of ``buffer_m``.
            server_error: phase-2 residual buffer for this rank's chunk.
            local_rank: CUDA device index selected via ``cupy.cuda.Device``.

        Returns:
            The reduced tensor (possibly a trimmed/reshaped view rather than
            the original object).
        """
        # NOTE(review): all_start_time / gather_start / gather_end are
        # collected but never used here — presumably timing hooks.
        all_start_time = time.time()
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        cupy.cuda.Device(local_rank).use()
        # Pad buffer_m up to the worker_error size (assumed >= original_size
        # so it divides evenly into self.size chunks — TODO confirm caller
        # guarantees this).
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])
        # Error feedback: fold in last round's residual before compressing.
        buffer_m.add_(worker_error)
        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
        # sign().add_(1).bool().float().add_(-0.5).mul_(2.0) maps values to
        # +/-1 (zeros become +1); residual = original - scale * signs.
        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
        # Pack the sign bits into self.size chunks (one per rank). Note
        # sign_() here mutates buffer_m in place.
        cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
            self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()), self.size)
        cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)
        cupy_recvbuf_sign = cupy.zeros([self.size, cupy_sign_list_packed[self.rank].size],
                                       dtype=cupy_sign_list_packed[0].dtype)
        cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)

        # Communication Phase 1
        gather_start = time.time()
        if self.cuda_aware:
            # In-place on the cupy buffers; nothing returned.
            self.gather_cuda(self.rank, self.size, self.comm, cupy_sign_list_packed, cupy_recvbuf_sign,
                             cupy_worker_scale, cupy_recvbuf_scale)
        else:
            # Host-staged path returns fresh cupy buffers.
            _, cupy_recvbuf_sign, _, cupy_recvbuf_scale = self.gather_host(self.rank, self.size, self.comm,
                                                                           cupy_sign_list_packed, cupy_recvbuf_sign,
                                                                           cupy_worker_scale, cupy_recvbuf_scale)
        gather_end = time.time()

        # Release the packed signs; they are no longer needed.
        # cupy_sign_list_packed, cupy_worker_scale, worker_scale = None, None, None
        cupy_sign_list_packed = None

        # Server-side step: unpack all ranks' signs, rescale each row by its
        # scale / world size, and sum rows -> this rank's averaged chunk.
        compensated_server_m = self.compression_backend.cupy2torch(
            (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(1 / self.size)).sum(0)
        # Second round of error feedback, mirroring the worker-side step.
        compensated_server_m.add_(server_error)
        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
        cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
        # Single chunk this time (chunk count = 1).
        cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
            self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool()), 1)
        compensated_server_m = None
        cupy_recvbuf_sign_server = cupy.zeros([self.size, cupy_server_sign_packed[0].size],
                                              dtype=cupy_recvbuf_sign.dtype)
        cupy_recvbuf_scale_server = cupy.zeros([self.size, 1], dtype=cupy_recvbuf_scale.dtype)
        # cupy_recvbuf_sign, cupy_recvbuf_scale = None, None
        cupy_recvbuf_sign = None

        # Communication Phase 2
        if self.cuda_aware:
            self.allgather_cuda(self.comm, cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale,
                                cupy_recvbuf_scale_server)
        else:
            _, cupy_recvbuf_sign_server, _, cupy_recvbuf_scale_server = self.allgather_host(
                self.comm, cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale,
                cupy_recvbuf_scale_server)

        # cupy_server_sign_packed, cupy_server_scale, server_scale = None, None, None
        cupy_server_sign_packed = None

        # Decompress: unpack every rank's server chunk, rescale row-wise, and
        # write the flattened result back into buffer_m in place.
        buffer_m.data.copy_(
            self.compression_backend.cupy2torch((cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
                self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                    self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)).flatten().data)
        # Drop the padding added earlier and restore the caller's shape.
        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)
        # cupy_recvbuf_sign_server, cupy_recvbuf_scale_server = None, None
        return buffer_m