# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

"""Latency micro-benchmark for MpiBackend.compressed_allreduce.

Run under mpirun/mpiexec; every rank participates in the allreduce and
rank 0 prints per-iteration latency plus min/max/mean in milliseconds.
"""

from mpi4py import MPI
import torch
import deepspeed
from deepspeed.runtime.comm.mpi import MpiBackend

# Configure wall clock timer
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.accelerator import get_accelerator
from statistics import mean

timers = SynchronizedWallClockTimer()

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())

# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)

local_rank = rank % get_accelerator().device_count()
device = torch.device(get_accelerator().device_name(), local_rank)

tensor_size = 300 * 2**20

# compressed_allreduce needs the tensor length to be a multiple of 8*size
# (1-bit packing across `size` ranks); pad up to the next such multiple.
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Error-feedback buffers carried across calls by the 1-bit compression.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

warmup = 10
iters = 10

# Warmup
for _ in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)

# Timed iterations; elapsed() is read after each stop so every entry is
# the latency of a single allreduce.
time_list = []
for _ in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())

timers.log(names=['compressed_allreduce'], normalizer=1, memory_breakdown=None)

places = 2  # decimal places for the mean
convert = 1e3  # seconds -> milliseconds

if rank == 0:
    for lat in time_list:
        print("latency = ", lat * convert)
    minlat = round(min(time_list) * convert)
    maxlat = round(max(time_list) * convert)
    meanlat = round(mean(time_list) * convert, places)
    print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))