# test_mpi_perf.py (2.1 KB)

  1. from mpi4py import MPI
  2. import time
  3. import torch
  4. import torch.distributed as dist
  5. import numpy as np
  6. import deepspeed
  7. from deepspeed.runtime.comm.mpi import MpiBackend
  8. # Configure wall clock timer
  9. from deepspeed.utils.timer import SynchronizedWallClockTimer
  10. from statistics import mean
  11. timers = SynchronizedWallClockTimer()
  12. comm = MPI.COMM_WORLD
  13. size = comm.Get_size()
  14. rank = comm.Get_rank()
  15. deepspeed.init_distributed(dist_backend='nccl')
  16. # Change cuda_aware to True to test out CUDA-Aware MPI communication
  17. backend = MpiBackend(cuda_aware=False)
  18. device = torch.device('cuda', rank % torch.cuda.device_count())
  19. tensor_size = 300 * 2**20
  20. server_size = int(tensor_size / size)
  21. if tensor_size % (8 * size) != 0:
  22. right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
  23. else:
  24. right_tensor_size = tensor_size
  25. right_server_size = right_tensor_size // size
  26. # Adding bias to the initialization of the gradient we are communicating
  27. # In order to get rid of the case where some elements in the gradient are too small
  28. a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
  29. worker_error = torch.zeros(right_tensor_size, device=device)
  30. server_error = torch.zeros(right_server_size, device=device)
  31. warmup = 10
  32. iters = 10
  33. local_rank = rank % torch.cuda.device_count()
  34. # Warmup
  35. for i in range(warmup):
  36. backend.compressed_allreduce(a, worker_error, server_error, local_rank)
  37. time_list = []
  38. for i in range(iters):
  39. timers('compressed_allreduce').start()
  40. backend.compressed_allreduce(a, worker_error, server_error, local_rank)
  41. timers('compressed_allreduce').stop()
  42. time_list.append(timers('compressed_allreduce').elapsed())
  43. timer_names = ['compressed_allreduce']
  44. timers.log(names=timer_names, normalizer=1, memory_breakdown=None)
  45. places = 2
  46. convert = 1e3
  47. float_size = 4
  48. if rank == 0:
  49. for i in range(iters):
  50. lat = time_list[i]
  51. print("latency = ", lat * convert)
  52. minlat = round(min(time_list) * convert)
  53. maxlat = round(max(time_list) * convert)
  54. meanlat = round(mean(time_list) * convert, places)
  55. print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))