123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- from deepspeed import comm as dist
- import cupy
- import numpy as np
- from deepspeed.runtime.compression.cupy import CupyBackend
- from deepspeed.utils.torch import required_torch_version
- from deepspeed.accelerator import get_accelerator
class NcclBackend(object):
    """Compressed all-reduce backend built on ``deepspeed.comm`` collectives.

    Tensors are reduced by exchanging only their bit-packed signs together
    with one scalar scale per rank.  The quantization residual of each step
    is written back into caller-provided error buffers.  Bit packing and
    unpacking is delegated to ``CupyBackend`` / CuPy.
    """

    def __init__(self, mpu=None):
        """Set up the process group and the compression helper.

        Args:
            mpu: Optional object exposing ``get_data_parallel_group()``.  When
                omitted, a new group spanning every rank reported by
                ``dist.get_world_size()`` is created instead.
        """
        # Always define the attribute so ``self.mpu`` can be read (as None)
        # on instances that were built without an mpu.
        self.mpu = mpu
        if mpu is None:
            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
        else:
            self.world_group = self.mpu.get_data_parallel_group()
        self.rank = dist.get_rank(group=self.world_group)
        self.size = dist.get_world_size(group=self.world_group)
        self.compression_backend = CupyBackend()
        # When True, boolean sign masks are cast to uint8 before they are
        # handed to CuPy (see ``_pack_sign_bits``).
        # NOTE(review): the literal 1.10 is the float 1.1 -- confirm how
        # ``required_torch_version`` interprets it before relying on "1.10".
        self.bool_not_supported = required_torch_version(min_version=1.10)

    def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
        """Start a gather to ``root`` using non-blocking point-to-point calls.

        Args:
            rank: Rank of the calling process.
            size: Number of ranks taking part.
            group: Process group forwarded to ``dist.irecv`` / ``dist.isend``.
            sendbuf: Buffer contributed by this rank.
            recvbuf: Container indexable by rank; only used on ``root``.
            root: Rank that collects the data.

        Returns:
            list: The handles returned by ``dist.irecv`` / ``dist.isend``;
            empty when ``root`` is the only participant.
        """
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
                else:
                    # Root's own slot is rebound to ``sendbuf`` (no copy).
                    recvbuf[rank] = sendbuf
        else:
            req.append(dist.isend(sendbuf, group=group, dst=root))
        return req

    def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
        """Gather to ``root`` using blocking ``dist.send`` / ``dist.recv``.

        Takes the same arguments as ``my_igather``; returns ``None``.
        """
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    dist.recv(recvbuf[idx], src=idx, group=group)
                else:
                    # Root's own slot is rebound to ``sendbuf`` (no copy).
                    recvbuf[rank] = sendbuf
        else:
            dist.send(sendbuf, group=group, dst=root)

    def _pack_sign_bits(self, tensor, num_chunks):
        """Bit-pack the signs of ``tensor`` into ``num_chunks`` CuPy chunks.

        ``tensor`` is overwritten in place (``sign_().add_(1)``).  Elements
        that are >= 0 become set bits, negative elements become cleared bits.
        The mask is cast to uint8 first when ``self.bool_not_supported`` is
        set.

        Returns:
            Whatever ``compression_backend.compress_by_chunk`` returns
            (indexed per chunk by the caller).
        """
        sign_mask = tensor.sign_().add_(1).bool()
        if self.bool_not_supported:
            sign_mask = sign_mask.to(dtype=torch.uint8)
        return self.compression_backend.compress_by_chunk(self.compression_backend.torch2cupy(sign_mask),
                                                          num_chunks)

    def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error, server_error, local_rank):
        """All-reduce ``buffer_m`` using sign compression with error buffers.

        Steps:
          1. Add ``worker_error`` to the (flattened, zero-padded) input,
             compute ``worker_scale = ||x||_2 / sqrt(numel)`` and store the
             residual ``x - worker_scale * sign(x)`` in ``worker_error``.
          2. Bit-pack the signs into ``self.size`` chunks, exchange them with
             ``all_to_all_single`` and all-gather the per-rank scales.
          3. Rebuild the received chunks as +/-1 values weighted by
             ``scale / self.size``, sum over ranks, add ``server_error`` and
             repeat the scale / residual / packing step on the result.
          4. All-gather the server signs and scales, rebuild the values and
             copy them into ``buffer_m``.

        Args:
            buffer_m: Tensor to reduce.  It is flattened when it has more
                than one dimension and zero-padded up to
                ``worker_error.numel()`` when the sizes differ.  In-place
                ops are applied to it, so the caller's tensor may be
                modified.
            worker_error: Worker-side residual; replaced via ``set_``.
            server_error: Server-side residual; replaced via ``set_``.
            local_rank: Device index passed to ``cupy.cuda.Device`` and to
                the accelerator's ``device_name``.

        Returns:
            torch.Tensor: The reconstructed result, trimmed to the original
            element count and reshaped to the original shape.
        """
        backend = self.compression_backend

        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        cupy.cuda.Device(local_rank).use()

        # Zero-pad so the buffer has as many elements as the error tensor.
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        # Worker side: compensate, then record the new quantization residual.
        # The sign chain maps negative values to -1 and everything else
        # (zero included) to +1.
        buffer_m.add_(worker_error)
        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(buffer_m.numel())
        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        # From here on buffer_m only holds intermediate sign data; it is
        # overwritten with the final result by the copy_ at the end.
        cupy_sign_list_packed = self._pack_sign_bits(buffer_m, self.size)
        cupy_worker_scale = backend.torch2cupy(worker_scale)

        cupy_recvbuf_sign = cupy.zeros([self.size, cupy_sign_list_packed[self.rank].size],
                                       dtype=cupy_sign_list_packed[0].dtype)

        sign_list_packed = [backend.cupy2torch(cupy_sign_list_packed[idx]) for idx in range(self.size)]
        recvbuf_sign = backend.cupy2torch(cupy_recvbuf_sign)
        recvbuf_scale = [
            torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(get_accelerator().device_name(local_rank)))
            for _ in range(self.size)
        ]

        # Communication phase 1: all-to-all for the sign chunks, all-gather
        # for the scales.
        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

        # Drop the reference to the packed chunks that are no longer needed.
        cupy_sign_list_packed = None
        cupy_recvbuf_sign = backend.torch2cupy(recvbuf_sign)

        # Server side: bits {0, 1} -> {-1, +1}, weight each rank's row by
        # scale / size (the stacked scales broadcast over the row), then sum
        # over the rank dimension.
        compensated_server_m = backend.cupy2torch(
            (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
        compensated_server_m.add_(server_error)
        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        cupy_server_sign_packed = self._pack_sign_bits(compensated_server_m, 1)
        compensated_server_m = None

        cupy_recvbuf_sign_server = cupy.zeros([self.size, cupy_server_sign_packed[0].size],
                                              dtype=cupy_recvbuf_sign.dtype)
        cupy_recvbuf_sign = None

        server_sign_packed = [backend.cupy2torch(cupy_server_sign_packed[0])]
        recvbuf_sign_server = [backend.cupy2torch(cupy_recvbuf_sign_server[idx]) for idx in range(self.size)]

        # The receive list is built from the rows of this CuPy array; the
        # array itself is converted back to a tensor after the all-gather.
        cupy_recvbuf_scale_server = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
        recvbuf_scale_server = [backend.cupy2torch(cupy_recvbuf_scale_server[idx]) for idx in range(self.size)]

        # Communication phase 2: all-gather the server signs and scales.
        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

        cupy_server_sign_packed = None

        # dist.all_gather only fills a list of tensors; stack it into a
        # single tensor before converting back to CuPy.
        recvbuf_sign_server = torch.stack(recvbuf_sign_server)
        cupy_recvbuf_sign_server = backend.torch2cupy(recvbuf_sign_server)

        # Rebuild +/-1 values, scale each rank's row by that rank's server
        # scale and write the flattened result into buffer_m.
        buffer_m.data.copy_(
            backend.cupy2torch((cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
                self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                    backend.cupy2torch(cupy_recvbuf_scale_server)).flatten().data)

        # Undo the padding and flattening applied at the top.
        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        return buffer_m
|