mpi.py

'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import torch
import cupy
import time
import numpy as np
from mpi4py import MPI

from deepspeed.runtime.compression.cupy import CupyBackend
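
# CupyBackend is assumed to provide torch2cupy/cupy2torch tensor conversions
# and compress_by_chunk, which bit-packs a boolean sign tensor and splits it
# into the requested number of chunks; those are the only members used below.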

class MpiBackend(object):
    """mpi4py communication backend for error-compensated 1-bit compressed
    allreduce. Offers a CUDA-aware path that communicates cupy buffers
    directly and a host path that stages the data through numpy buffers."""
    def __init__(self, cuda_aware):
        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()
        self.cuda_aware = cuda_aware
        self.compression_backend = CupyBackend()

    def my_igather(self, rank, size, comm, sendbuf, recbuf, root):
        # Non-blocking gather to `root`: the root posts an Irecv for every
        # other rank and copies its own contribution locally; all other ranks
        # post a single Isend. Returns the list of outstanding requests.
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(comm.Irecv(recbuf[idx], source=idx))
                else:
                    recbuf[rank] = sendbuf
        else:
            req.append(comm.Isend(sendbuf, dest=root))
        return req

    def gather_cuda(self,
                    rank,
                    world_size,
                    comm,
                    cupy_sign_list_packed,
                    cupy_recvbuf_sign,
                    cupy_worker_scale,
                    cupy_recvbuf_scale):
        # We do in-place operations on cupy buffers so we do not return any buffers
        requests = []
        for idx in range(world_size):
            req_sign = self.my_igather(rank,
                                       world_size,
                                       comm,
                                       cupy_sign_list_packed[idx],
                                       cupy_recvbuf_sign,
                                       root=idx)
            requests += req_sign

        for idx in range(world_size):
            req_scale = self.my_igather(rank,
                                        world_size,
                                        comm,
                                        cupy_worker_scale,
                                        cupy_recvbuf_scale,
                                        root=idx)
            requests += req_scale

        MPI.Request.Waitall(requests)

    def gather_host(self,
                    rank,
                    world_size,
                    comm,
                    cupy_sign_list_packed,
                    cupy_recvbuf_sign,
                    cupy_worker_scale,
                    cupy_recvbuf_scale):
        # In-place operations are not possible for newly created cupy arrays
        # so we need to return the new buffers
        numpy_recvbuf_sign = np.zeros([world_size,
                                       cupy_sign_list_packed[rank].size],
                                      dtype=cupy_sign_list_packed[0].dtype)
        numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)

        # 1. convert from cupy to numpy
        numpy_sign_list_packed = cupy_sign_list_packed
        for idx in range(world_size):
            numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])
        numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
        numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)
        cupy.cuda.get_current_stream().synchronize()

        # 2. use numpy buffers for communication
        requests = []
        for idx in range(world_size):
            req_sign = self.my_igather(rank,
                                       world_size,
                                       comm,
                                       numpy_sign_list_packed[idx],
                                       numpy_recvbuf_sign,
                                       root=idx)
            requests += req_sign

        for idx in range(world_size):
            req_scale = self.my_igather(rank,
                                        world_size,
                                        comm,
                                        numpy_worker_scale,
                                        numpy_recvbuf_scale,
                                        root=idx)
            requests += req_scale

        MPI.Request.Waitall(requests)

        # 3. Convert back from numpy to cupy
        cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
        for idx in range(world_size):
            cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])
        cupy_worker_scale = cupy.asarray(numpy_worker_scale)
        cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
        cupy.cuda.get_current_stream().synchronize()

        return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale

    def allgather_cuda(self,
                       comm,
                       cupy_server_sign_packed,
                       cupy_recvbuf_sign_server,
                       cupy_server_scale,
                       cupy_recvbuf_scale_server):
        # Allgather directly on cupy buffers; requires a CUDA-aware MPI build.
        comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
        comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)

    def allgather_host(self,
                       comm,
                       cupy_server_sign_packed,
                       cupy_recvbuf_sign_server,
                       cupy_server_scale,
                       cupy_recvbuf_scale_server):
        # 1. Convert cupy to numpy
        numpy_recvbuf_sign_server = np.zeros(
            [comm.Get_size(),
             cupy_server_sign_packed.size],
            dtype=cupy_server_sign_packed.dtype)
        numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
                                               1],
                                              dtype=cupy_server_scale.dtype)

        numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
        numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
        numpy_server_scale = cupy.asnumpy(cupy_server_scale)
        numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
        cupy.cuda.get_current_stream().synchronize()

        # 2. Communicate numpy buffers
        comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
        comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
        comm.Barrier()

        # 3. Convert numpy back to cupy
        cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
        cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
        cupy_server_scale = cupy.asarray(numpy_server_scale)
        cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
        cupy.cuda.get_current_stream().synchronize()

        return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server

    def compressed_allreduce(self,
                             buffer_m: torch.tensor,
                             worker_error,
                             server_error,
                             local_rank):
        # Error-compensated 1-bit compressed allreduce: each worker sends the
        # packed sign bits of its (error-compensated) buffer plus a single
        # scale, a per-rank "server" chunk is reduced on every rank, and the
        # compressed server result is allgathered back. worker_error and
        # server_error are updated in place with the compression residuals.
        all_start_time = time.time()
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        cupy.cuda.Device(local_rank).use()

        # Pad the flattened buffer so it matches the worker error size
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size,
                                       device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        # Compensate with the previous worker error, then compress to sign + scale
        buffer_m.add_(worker_error)
        worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
        worker_error.set_(buffer_m - worker_scale *
                          buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
            self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()),
            self.size)
        cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)

        cupy_recvbuf_sign = cupy.zeros(
            [self.size,
             cupy_sign_list_packed[self.rank].size],
            dtype=cupy_sign_list_packed[0].dtype)
        cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)

        # Communication Phase 1
        gather_start = time.time()
        if self.cuda_aware:
            self.gather_cuda(self.rank,
                             self.size,
                             self.comm,
                             cupy_sign_list_packed,
                             cupy_recvbuf_sign,
                             cupy_worker_scale,
                             cupy_recvbuf_scale)
        else:
            _, cupy_recvbuf_sign, _, cupy_recvbuf_scale = self.gather_host(
                self.rank,
                self.size,
                self.comm,
                cupy_sign_list_packed,
                cupy_recvbuf_sign,
                cupy_worker_scale,
                cupy_recvbuf_scale)
        gather_end = time.time()

        # cupy_sign_list_packed, cupy_worker_scale, worker_scale = None, None, None
        cupy_sign_list_packed = None

        # Decompress the gathered worker chunks, average them, and compensate
        # with the previous server error before compressing the server result
        compensated_server_m = self.compression_backend.cupy2torch(
            (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
                self.size,
                -1)).float().add_(-0.5).mul_(2.0).mul_(
                    self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(
                        1 / self.size)).sum(0)
        compensated_server_m.add_(server_error)
        server_scale = torch.norm(compensated_server_m) / np.sqrt(
            compensated_server_m.numel())
        server_error.set_(
            compensated_server_m - server_scale *
            compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
        cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
            self.compression_backend.torch2cupy(
                compensated_server_m.sign_().add_(1).bool()),
            1)
        compensated_server_m = None

        cupy_recvbuf_sign_server = cupy.zeros(
            [self.size,
             cupy_server_sign_packed[0].size],
            dtype=cupy_recvbuf_sign.dtype)
        cupy_recvbuf_scale_server = cupy.zeros([self.size,
                                                1],
                                               dtype=cupy_recvbuf_scale.dtype)
        # cupy_recvbuf_sign, cupy_recvbuf_scale = None, None
        cupy_recvbuf_sign = None

        # Communication Phase 2
        if self.cuda_aware:
            self.allgather_cuda(self.comm,
                                cupy_server_sign_packed[0],
                                cupy_recvbuf_sign_server,
                                cupy_server_scale,
                                cupy_recvbuf_scale_server)
        else:
            _, cupy_recvbuf_sign_server, _, cupy_recvbuf_scale_server = self.allgather_host(
                self.comm,
                cupy_server_sign_packed[0],
                cupy_recvbuf_sign_server,
                cupy_server_scale,
                cupy_recvbuf_scale_server)

        # cupy_server_sign_packed, cupy_server_scale, server_scale = None, None, None
        cupy_server_sign_packed = None

        # Decompress the allgathered server chunks into the output buffer
        buffer_m.data.copy_(
            self.compression_backend.cupy2torch(
                (cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
                    self.size,
                    -1)).float().add_(-0.5).mul_(2.0).mul_(
                        self.compression_backend.cupy2torch(
                            cupy_recvbuf_scale_server)).flatten().data)
        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)
        # cupy_recvbuf_sign_server, cupy_recvbuf_scale_server = None, None

        return buffer_m
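

# ---------------------------------------------------------------------------
# Minimal usage sketch: one way compressed_allreduce could be driven under
# mpirun, with cuda_aware=False so communication is staged through numpy.
# The buffer size, device selection, and error-buffer shapes are assumptions
# for illustration; the structural requirements visible in the code above are
# that worker_error matches the (padded) gradient size and server_error holds
# one per-rank chunk of it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    backend = MpiBackend(cuda_aware=False)
    local_rank = backend.rank  # assumes one GPU per MPI rank on a single node
    device = torch.device('cuda', local_rank)

    numel = 1024 * backend.size  # keeps each packed sign chunk byte-aligned
    buffer_m = torch.randn(numel, device=device)
    worker_error = torch.zeros(numel, device=device)
    server_error = torch.zeros(numel // backend.size, device=device)

    result = backend.compressed_allreduce(buffer_m,
                                          worker_error,
                                          server_error,
                                          local_rank)
    print(f"rank {backend.rank}: allreduced norm = {result.norm().item():.4f}")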