# Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team import torch import cupy import time import numpy as np from mpi4py import MPI from deepspeed.runtime.compression.cupy import CupyBackend class MpiBackend(object): def __init__(self, cuda_aware): self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() self.cuda_aware = cuda_aware self.compression_backend = CupyBackend() def my_igather(self, rank, size, comm, sendbuf, recbuf, root): req = [] if rank == root: for idx in range(size): if idx != rank: req.append(comm.Irecv(recbuf[idx], source=idx)) else: recbuf[rank] = sendbuf else: req.append(comm.Isend(sendbuf, dest=root)) return req def gather_cuda(self, rank, world_size, comm, cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale): # We do in-place operations on cupy buffers so we do not return any buffers requests = [] for idx in range(world_size): req_sign = self.my_igather(rank, world_size, comm, cupy_sign_list_packed[idx], cupy_recvbuf_sign, root=idx) requests += req_sign for idx in range(world_size): req_scale = self.my_igather(rank, world_size, comm, cupy_worker_scale, cupy_recvbuf_scale, root=idx) requests += req_scale MPI.Request.Waitall(requests) def gather_host(self, rank, world_size, comm, cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale): # In-place operations are not possible for newly created cupy arrays # so we need to return the new buffers numpy_recvbuf_sign = np.zeros([world_size, cupy_sign_list_packed[rank].size], dtype=cupy_sign_list_packed[0].dtype) numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype) # 1. convert from cupy to numpy numpy_sign_list_packed = cupy_sign_list_packed for idx in range(world_size): numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx]) numpy_worker_scale = cupy.asnumpy(cupy_worker_scale) numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale) cupy.cuda.get_current_stream().synchronize() # 2. use numpy buffers for communication requests = [] for idx in range(world_size): req_sign = self.my_igather(rank, world_size, comm, numpy_sign_list_packed[idx], numpy_recvbuf_sign, root=idx) requests += req_sign for idx in range(world_size): req_scale = self.my_igather(rank, world_size, comm, numpy_worker_scale, numpy_recvbuf_scale, root=idx) requests += req_scale MPI.Request.Waitall(requests) # 3. Convert back from numpy to cupy cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign) for idx in range(world_size): cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx]) cupy_worker_scale = cupy.asarray(numpy_worker_scale) cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale) cupy.cuda.get_current_stream().synchronize() return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale def allgather_cuda(self, comm, cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server): comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server) comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server) def allgather_host(self, comm, cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server): # 1. Convert cupy to numpy numpy_recvbuf_sign_server = np.zeros([comm.Get_size(), cupy_server_sign_packed.size], dtype=cupy_server_sign_packed.dtype) numpy_recvbuf_scale_server = np.zeros([comm.Get_size(), 1], dtype=cupy_server_scale.dtype) numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed) numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server) numpy_server_scale = cupy.asnumpy(cupy_server_scale) numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server) cupy.cuda.get_current_stream().synchronize() # 2. Communicate numpy buffers comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server) comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server) comm.Barrier() # 3. Convert numpy back to cupy cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed) cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server) cupy_server_scale = cupy.asarray(numpy_server_scale) cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server) cupy.cuda.get_current_stream().synchronize() return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank): all_start_time = time.time() original_shape = buffer_m.size() if len(original_shape) > 1: buffer_m = torch.flatten(buffer_m) original_size = buffer_m.numel() worker_error_size = worker_error.numel() cupy.cuda.Device(local_rank).use() if original_size != worker_error_size: empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device) buffer_m = torch.cat([buffer_m, empty_tensor]) buffer_m.add_(worker_error) worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) cupy_sign_list_packed = self.compression_backend.compress_by_chunk( self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()), self.size) cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale) cupy_recvbuf_sign = cupy.zeros([self.size, cupy_sign_list_packed[self.rank].size], dtype=cupy_sign_list_packed[0].dtype) cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype) # Communication Phase 1 gather_start = time.time() if self.cuda_aware: self.gather_cuda(self.rank, self.size, self.comm, cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale) else: _, cupy_recvbuf_sign, _, cupy_recvbuf_scale = self.gather_host(self.rank, self.size, self.comm, cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale) gather_end = time.time() # cupy_sign_list_packed, cupy_worker_scale, worker_scale = None, None, None cupy_sign_list_packed = None compensated_server_m = self.compression_backend.cupy2torch( (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_( self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(1 / self.size)).sum(0) compensated_server_m.add_(server_error) server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) server_error.set_(compensated_server_m - server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) cupy_server_scale = self.compression_backend.torch2cupy(server_scale) cupy_server_sign_packed = self.compression_backend.compress_by_chunk( self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool()), 1) compensated_server_m = None cupy_recvbuf_sign_server = cupy.zeros([self.size, cupy_server_sign_packed[0].size], dtype=cupy_recvbuf_sign.dtype) cupy_recvbuf_scale_server = cupy.zeros([self.size, 1], dtype=cupy_recvbuf_scale.dtype) # cupy_recvbuf_sign, cupy_recvbuf_scale = None, None cupy_recvbuf_sign = None # Communication Phase 2 if self.cuda_aware: self.allgather_cuda(self.comm, cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server) else: _, cupy_recvbuf_sign_server, _, cupy_recvbuf_scale_server = self.allgather_host( self.comm, cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server) # cupy_server_sign_packed, cupy_server_scale, server_scale = None, None, None cupy_server_sign_packed = None buffer_m.data.copy_( self.compression_backend.cupy2torch((cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape( self.size, -1)).float().add_(-0.5).mul_(2.0).mul_( self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)).flatten().data) if original_size != worker_error_size: buffer_m = buffer_m[0:original_size] if len(original_shape) > 1: buffer_m = buffer_m.reshape(original_shape) # cupy_recvbuf_sign_server, cupy_recvbuf_scale_server = None, None return buffer_m