# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed import comm as dist
import cupy
import numpy as np

from deepspeed.runtime.compression.cupy import CupyBackend
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator


class NcclBackend(object):

    def __init__(self, mpu=None):
        if mpu is None:
            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
        else:
            self.mpu = mpu
            self.world_group = self.mpu.get_data_parallel_group()
        self.rank = dist.get_rank(group=self.world_group)
        self.size = dist.get_world_size(group=self.world_group)
        self.compression_backend = CupyBackend()
        # From torch 1.10 on, bool tensors cannot be handed to cupy directly,
        # so the compression path casts to uint8 first
        self.bool_not_supported = required_torch_version(min_version=1.10)

    def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
        # Non-blocking gather built from point-to-point ops; returns the pending requests
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
                else:
                    recvbuf[rank] = sendbuf
        else:
            req.append(dist.isend(sendbuf, group=group, dst=root))
        return req

    def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
        # Blocking gather built from point-to-point ops
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    dist.recv(recvbuf[idx], src=idx, group=group)
                else:
                    recvbuf[rank] = sendbuf
        else:
            dist.send(sendbuf, group=group, dst=root)

    def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error, server_error, local_rank):

        # all_start_time = time.time()
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        cupy.cuda.Device(local_rank).use()

        # Pad the buffer so it splits evenly into one chunk per rank
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        # Error feedback: fold in the residual left over from the previous step's compression
        buffer_m.add_(worker_error)
        # Per-tensor scale for 1-bit quantization (RMS of the compensated buffer)
        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(buffer_m.numel())
        # New residual: compensated buffer minus its 1-bit (sign * scale) reconstruction
        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        if self.bool_not_supported:
            cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
                self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool().to(dtype=torch.uint8)),
                self.size)
        else:
            cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
                self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()), self.size)
        cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)

        cupy_recvbuf_sign = cupy.zeros([self.size, cupy_sign_list_packed[self.rank].size],
                                       dtype=cupy_sign_list_packed[0].dtype)
        # cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)

        sign_list_packed = [
            self.compression_backend.cupy2torch(cupy_sign_list_packed[idx]) for idx in range(self.size)
        ]

        # worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale)
        recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign)
        #recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale)
        recvbuf_scale = [
            torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(get_accelerator().device_name(local_rank)))
            for i in range(self.size)
        ]

        # communication phase 1
        # gather_start = time.time()
        # Alltoall for sign
        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
        # Allgather for scale
        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

        # gather_end = time.time()

        # cupy_sign_list_packed, sign_list_packed, cupy_worker_scale, worker_scale = None, None, None, None
        cupy_sign_list_packed = None
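        # Phase 2 (server aggregation): after the all-to-all, this rank holds the
        # packed 1-bit chunks of its assigned slice from every worker. It averages
        # them with the gathered scales, applies its own error feedback, then
        # re-compresses the slice to 1 bit for the final all-gather.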
        cupy_recvbuf_sign = self.compression_backend.torch2cupy(recvbuf_sign)
        #cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale))

        # Decompress: map bits {0, 1} -> {-1, +1}, rescale per worker, average over ranks
        compensated_server_m = self.compression_backend.cupy2torch(
            (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
        # Server-side error feedback, mirroring the worker-side compensation above
        compensated_server_m.add_(server_error)
        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        # cupy_server_scale = self.compression_backend.torch2cupy(server_scale)

        if self.bool_not_supported:
            cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
                self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool().to(
                    dtype=torch.uint8)), 1)
        else:
            cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
                self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool()), 1)
        compensated_server_m = None

        cupy_recvbuf_sign_server = cupy.zeros([self.size, cupy_server_sign_packed[0].size],
                                              dtype=cupy_recvbuf_sign.dtype)
        # cupy_recvbuf_sign, recvbuf_sign = None, None
        cupy_recvbuf_sign = None

        server_sign_packed = [self.compression_backend.cupy2torch(cupy_server_sign_packed[0])]
        recvbuf_sign_server = [
            self.compression_backend.cupy2torch(cupy_recvbuf_sign_server[idx]) for idx in range(self.size)
        ]

        # server_scale = self.compression_backend.cupy2torch(cupy_server_scale)
        cupy_recvbuf_scale_server = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
        # cupy_recvbuf_scale, recvbuf_scale = None, None

        recvbuf_scale_server = [
            self.compression_backend.cupy2torch(cupy_recvbuf_scale_server[idx]) for idx in range(self.size)
        ]

        # Communication Phase 2
        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

        cupy_server_sign_packed = None

        # need to convert from a tensor list to a single tensor
        # dist.all_gather only provides a tensor list as the recv/output buffer
        recvbuf_sign_server = torch.stack(recvbuf_sign_server)

        cupy_recvbuf_sign_server = self.compression_backend.torch2cupy(recvbuf_sign_server)

        # Reconstruct the full averaged tensor from every rank's 1-bit chunk and scale
        buffer_m.data.copy_(
            self.compression_backend.cupy2torch((cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
                self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                    self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)).flatten().data)

        # Strip any padding and restore the caller's original shape
        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        return buffer_m
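

# Minimal usage sketch, not part of this module's API: it illustrates how an
# optimizer might drive compressed_allreduce. It assumes a distributed launch
# (e.g. via `deepspeed` or `torchrun`) with CUDA and cupy available, and a
# tensor size (4096 here) chosen so it already splits evenly into per-rank
# 8-bit-packed chunks and no padding is needed; real callers such as the 1-bit
# optimizers pad to a safe size and keep the error-feedback buffers in
# optimizer state across steps.
if __name__ == "__main__":
    import deepspeed

    deepspeed.init_distributed()
    local_rank = dist.get_rank() % get_accelerator().device_count()
    get_accelerator().set_device(local_rank)

    backend = NcclBackend()
    device = torch.device(get_accelerator().device_name(local_rank))
    grad = torch.randn(4096, device=device)

    # Persistent error-feedback buffers: worker_error spans the full (padded)
    # tensor; server_error spans only the 1/world_size slice this rank owns.
    worker_error = torch.zeros(grad.numel(), device=device)
    server_error = torch.zeros(grad.numel() // backend.size, device=device)

    grad = backend.compressed_allreduce(grad, worker_error, server_error, local_rank)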