test_nccl_perf.py

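"""Microbenchmark for DeepSpeed's NCCL-based compressed (1-bit) all-reduce.

Times NcclBackend.compressed_allreduce over a 300 * 2**20-element fp32 tensor
(~1.2 GB) and, on rank 0, reports per-iteration latency, algorithm throughput,
and bus bandwidth.

Example launch (illustrative; exact flags depend on your setup), using the
DeepSpeed launcher, which exports LOCAL_RANK for each process:

    deepspeed --num_nodes=1 --num_gpus=8 test_nccl_perf.py
"""
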
import time
import torch
import torch.distributed as dist
import numpy as np
import argparse
import deepspeed
import os
from deepspeed.runtime.comm.nccl import NcclBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean

timers = SynchronizedWallClockTimer()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend='nccl')
# LOCAL_RANK is exported by the DeepSpeed (or torchrun) launcher.
args.local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

backend = NcclBackend()
local_rank = args.local_rank

# Set the tensor size (roughly BERT-Large sized).
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)

# Pad the tensor so its length is divisible by 8 * world size, i.e. each
# per-worker chunk is a multiple of 8 elements (sign bits pack into whole bytes).
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Bias the random gradient we are communicating by 0.01 * rank so that none
# of its elements are too small.
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Persistent error-feedback buffers required by compressed_allreduce.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

warmup = 10
iters = 10

# Warmup
for i in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)

time_list = []

# Local reference for the 1-bit representation: sign(a) (zeros mapped to +1)
# scaled by ||a|| / sqrt(numel). Used only for the shape print below and for
# the commented-out dense all-reduce comparison.
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign

print("Shape of the compressed buffer:", a_compressed.shape) if rank == 0 else None

for i in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    # torch.distributed.all_reduce(a_compressed)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())

# timer_names = ['compressed_allreduce']
# timers.log(names=timer_names, normalizer=1, memory_breakdown=None)

places = 2
convert = 1e3  # seconds -> milliseconds
float_size = 4  # bytes per fp32 element

if rank == 0:
    for i in range(iters):
        lat = time_list[i]
        print("latency = ", lat * convert)

minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
print("min, max, and mean = {} ms, {} ms, {} ms".format(
    minlat, maxlat, meanlat)) if rank == 0 else None
# print("tensor shape", a.shape)

duration = meanlat / 1e3  # mean latency back in seconds
tput = (tensor_size * float_size) / duration
print("algo throughput: %f Bytes/s, %f GB/s" % (tput, tput / 1e9)) if rank == 0 else None

# Bus bandwidth with the standard all-reduce correction factor 2 * (n - 1) / n
# (the convention used by nccl-tests). Note: `size` is reused here for the
# message size in bytes, shadowing the earlier world size.
size = tensor_size * float_size
n = dist.get_world_size()
busbw = (size / duration) * (2 * (n - 1) / n)
print("busbw: %f GB/s" % (busbw / 1e9)) if rank == 0 else None