test_mpi_perf.py

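# Micro-benchmark for DeepSpeed's MPI-based compressed (1-bit) allreduce:
# it times MpiBackend.compressed_allreduce on a large tensor and reports
# per-iteration latency. Assumes an MPI launch (e.g. mpirun with one process
# per GPU); adapt the launcher to your environment.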
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.comm.mpi import MpiBackend

# Configure wall clock timer
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean

timers = SynchronizedWallClockTimer()
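# mpi4py supplies the rank/world size handed out by the MPI launcher;
# deepspeed.init_distributed sets up the torch.distributed (NCCL) process group.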
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

deepspeed.init_distributed(dist_backend='nccl')
# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)

device = torch.device('cuda', rank % torch.cuda.device_count())
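# Round the tensor length up to a multiple of 8 * world size so the
# 1-bit-compressed payload packs into whole bytes and partitions evenly
# across ranks.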
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Add a small rank-dependent bias to the gradient being communicated so that
# none of its elements end up too close to zero.
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
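# Error-feedback buffers that carry the compression residual on the worker
# and server sides between calls.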
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
warmup = 10
iters = 10
local_rank = rank % torch.cuda.device_count()

# Warmup: untimed iterations so one-time setup costs do not skew the results
for i in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
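# Timed iterations: the synchronized wall-clock timer brackets each
# compressed_allreduce call.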
time_list = []
for i in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())
timer_names = ['compressed_allreduce']
timers.log(names=timer_names, normalizer=1, memory_breakdown=None)
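# Rank 0 reports per-iteration latency and min/max/mean in milliseconds.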
places = 2
convert = 1e3  # convert seconds to milliseconds
float_size = 4  # bytes per float32 element (not referenced below)

if rank == 0:
    for i in range(iters):
        lat = time_list[i]
        print("latency = ", lat * convert)
    minlat = round(min(time_list) * convert)
    maxlat = round(max(time_list) * convert)
    meanlat = round(mean(time_list) * convert, places)
    print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))