# nccl.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed import comm as dist
import cupy
import numpy as np

from deepspeed.runtime.compression.cupy import CupyBackend
from deepspeed.runtime.utils import required_torch_version
from deepspeed.accelerator import get_accelerator
  11. class NcclBackend(object):
  12. def __init__(self, mpu=None):
  13. if mpu is None:
  14. self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
  15. else:
  16. self.mpu = mpu
  17. self.world_group = self.mpu.get_data_parallel_group()
  18. self.rank = dist.get_rank(group=self.world_group)
  19. self.size = dist.get_world_size(group=self.world_group)
  20. self.compression_backend = CupyBackend()
  21. self.bool_not_supported = required_torch_version(min_version=1.10)
  22. def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
  23. req = []
  24. if rank == root:
  25. for idx in range(size):
  26. if idx != rank:
  27. req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
  28. else:
  29. recvbuf[rank] = sendbuf
  30. else:
  31. req.append(dist.isend(sendbuf, group=group, dst=root))
  32. return req
  33. def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
  34. if rank == root:
  35. for idx in range(size):
  36. if idx != rank:
  37. dist.recv(recvbuf[idx], src=idx, group=group)
  38. else:
  39. recvbuf[rank] = sendbuf
  40. else:
  41. dist.send(sendbuf, group=group, dst=root)
    def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank):
        """Error-compensated 1-bit (sign) compressed allreduce.

        Flattens ``buffer_m``, folds in the locally accumulated
        ``worker_error``, compresses to packed sign bits plus one scalar
        scale, exchanges the compressed chunks (alltoall for signs,
        allgather for scales), then applies the same sign compression to the
        server-side partial sum (compensated by ``server_error``) and
        allgathers the result back to every rank. ``worker_error`` and
        ``server_error`` are updated in place with the new compression
        residuals; the decompressed result is copied back into ``buffer_m``
        (restored to its original shape) and returned.

        NOTE(review): assumes buffer_m is a dense floating-point CUDA tensor
        and that worker_error.numel() is padded to a multiple of self.size —
        confirm against the callers that pre-size the error buffers.
        """
        # all_start_time = time.time()
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        # Bind cupy to the same device the tensors live on.
        cupy.cuda.Device(local_rank).use()

        # Pad the flattened buffer up to the worker_error length so it splits
        # evenly into self.size chunks; the padding is stripped on the way out.
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        # Error feedback: add the residual of the previous round before
        # compressing again.
        buffer_m.add_(worker_error)
        # One scalar scale per worker: ||x|| / sqrt(n).
        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(buffer_m.numel())
        # New residual = input - scale * quantized(input); the chain
        # sign().add_(1).bool().float().add_(-0.5).mul_(2.0) maps negative
        # entries to -1 and non-negative entries to +1.
        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        # Pack the sign bits into self.size chunks; on torch >= 1.10 the bool
        # tensor is cast to uint8 first (see bool_not_supported in __init__).
        if self.bool_not_supported:
            cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
                self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool().to(dtype=torch.uint8)), self.size)
        else:
            cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
                self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()), self.size)
        cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)

        cupy_recvbuf_sign = cupy.zeros([self.size, cupy_sign_list_packed[self.rank].size],
                                       dtype=cupy_sign_list_packed[0].dtype)
        # cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)

        # Wrap the cupy buffers as torch tensors (zero-copy via the
        # compression backend) so they can be handed to dist collectives.
        sign_list_packed = [
            self.compression_backend.cupy2torch(cupy_sign_list_packed[idx]) for idx in range(self.size)
        ]

        # worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale)
        recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign)
        #recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale)
        recvbuf_scale = [
            torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(get_accelerator().device_name(local_rank)))
            for i in range(self.size)
        ]

        # communication phase 1
        # gather_start = time.time()
        # Alltoall for sign
        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
        # Allgather for scale
        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

        # gather_end = time.time()

        # cupy_sign_list_packed, sign_list_packed, cupy_worker_scale, worker_scale = None, None, None, None
        cupy_sign_list_packed = None

        cupy_recvbuf_sign = self.compression_backend.torch2cupy(recvbuf_sign)
        #cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale))

        # Decompress every worker's chunk (unpacked {0,1} bits -> {-1,+1},
        # times each worker's scale / size) and sum to form this rank's share
        # of the averaged "server" tensor; then apply server-side error
        # feedback and recompress exactly as on the worker side.
        compensated_server_m = self.compression_backend.cupy2torch(
            (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
        compensated_server_m.add_(server_error)
        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        # cupy_server_scale = self.compression_backend.torch2cupy(server_scale)

        if self.bool_not_supported:
            cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
                self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool().to(dtype=torch.uint8)),
                1)
        else:
            cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
                self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool()), 1)
        compensated_server_m = None

        cupy_recvbuf_sign_server = cupy.zeros([self.size, cupy_server_sign_packed[0].size],
                                              dtype=cupy_recvbuf_sign.dtype)
        # cupy_recvbuf_sign, recvbuf_sign = None, None
        cupy_recvbuf_sign = None

        server_sign_packed = [self.compression_backend.cupy2torch(cupy_server_sign_packed[0])]
        recvbuf_sign_server = [
            self.compression_backend.cupy2torch(cupy_recvbuf_sign_server[idx]) for idx in range(self.size)
        ]

        # server_scale = self.compression_backend.cupy2torch(cupy_server_scale)
        cupy_recvbuf_scale_server = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
        # cupy_recvbuf_scale, recvbuf_scale = None, None

        recvbuf_scale_server = [
            self.compression_backend.cupy2torch(cupy_recvbuf_scale_server[idx]) for idx in range(self.size)
        ]

        # Communication Phase 2
        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

        cupy_server_sign_packed = None

        # need to convert from a tensor list to a single tensor
        # dist.all_gather only provides a tensor list as the recv/output buffer
        recvbuf_sign_server = torch.stack(recvbuf_sign_server)

        cupy_recvbuf_sign_server = self.compression_backend.torch2cupy(recvbuf_sign_server)

        # Decompress every rank's server chunk and write the final result back
        # into the caller's buffer in place.
        buffer_m.data.copy_(
            self.compression_backend.cupy2torch((cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
                self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                    self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)).flatten().data)
        # Strip padding and restore the caller's original shape.
        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        return buffer_m