# test_mpi_backend.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from mpi4py import MPI
import torch
import deepspeed.comm as dist
import numpy as np
import deepspeed
from deepspeed.runtime.comm.mpi import MpiBackend
from deepspeed.accelerator import get_accelerator

# MPI topology: every rank in COMM_WORLD participates in this test.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# Bring up deepspeed.comm on the accelerator's preferred backend so the
# simulated path (torch_sim) has a working collective implementation.
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())

# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)

# Map this process to one local accelerator device.
local_rank = rank % get_accelerator().device_count()
device = torch.device(get_accelerator().device_name(), local_rank)
  19. # A simulated compression function using deepspeed.comm
  20. def torch_sim(a):
  21. a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
  22. scale = a.norm() / np.sqrt(a.numel())
  23. a_compressed = scale * a_sign
  24. a_sign = None
  25. worker_error = a - a_compressed
  26. dist.all_reduce(a_compressed)
  27. a_compressed.mul_(1 / dist.get_world_size())
  28. a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
  29. a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
  30. server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
  31. a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
  32. a_server_compressed = torch.cat([server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
  33. rank = dist.get_rank()
  34. server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
  35. get_accelerator().synchronize()
  36. dist.barrier()
  37. return a_server_compressed, worker_error, server_error
  38. tensor_size = 100 * 2**20
  39. server_size = int(tensor_size / size)
  40. if tensor_size % (8 * size) != 0:
  41. right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
  42. else:
  43. right_tensor_size = tensor_size
  44. right_server_size = right_tensor_size // size
# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Zero-initialized error-feedback buffers handed to the backend.
# NOTE(review): presumably compressed_allreduce fills these in place with the
# worker/server compression residuals — confirm against MpiBackend.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

# Reference result from the simulated (deepspeed.comm-based) implementation.
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
get_accelerator().empty_cache()

# Result under test from the MPI backend's compressed all-reduce.
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
  53. threshold = 1e-6
  54. magnitude_threshold = 1e-6
  55. diff_mask = (a_after - a_torch) > threshold
  56. diff_server_mask = torch.chunk(diff_mask, size)[rank]
  57. mpi_server = torch.chunk(a_after, size)[rank] + server_error
  58. torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
  59. test_correctness = True
  60. # If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
  61. # The test would skip those numbers that are too small in compensated_server_m
  62. if test_correctness:
  63. if torch.sum(diff_server_mask) == 0:
  64. print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
  65. else:
  66. check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
  67. if torch.sum(check_mag_mask) == 0:
  68. print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
  69. else:
  70. print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))