test_compressed_perf.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os

from deepspeed.runtime.comm.compressed import CompressedBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.accelerator import get_accelerator
from statistics import mean

timers = SynchronizedWallClockTimer()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
args.local_rank = int(os.environ['LOCAL_RANK'])
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

backend = CompressedBackend()
local_rank = args.local_rank

# Set tensor_size to roughly the parameter count of BERT-Large
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
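# Note: the padding above rounds the buffer up to a multiple of 8 * size,
# presumably so the 1-bit payload packs evenly into bytes (8 signs per byte)
# and shards evenly across the `size` ranks.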

# Bias the initialization of the gradient we are communicating so that no
# elements are too close to zero (the rank-dependent offset keeps the data
# distinct across ranks).
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
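# worker_error / server_error hold the compression residual that
# compressed_allreduce feeds back on later calls (error feedback); they are
# sized for the padded tensor and its per-rank shard, respectively.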

warmup = 10
iters = 10

# Warmup
for i in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)

time_list = []

# Dense reference for what the compressed payload represents:
# a 1-bit sign per element plus a single fp32 scale (the RMS of `a`).
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign

if rank == 0:
    print("Shape of the compressed buffer:", a_compressed.shape)
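
# A quick sanity check (a sketch, not part of the original benchmark): since
# `scale` is the RMS of `a`, the sign + scale round trip keeps a bounded
# relative error, e.g.:
#   rel_err = (a - a_compressed).norm() / a.norm()
#   if rank == 0:
#       print("relative compression error:", rel_err.item())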

for i in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    #deepspeed.comm.all_reduce(a_compressed)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())

#timer_names = ['compressed_allreduce']
#timers.log(names=timer_names, normalizer=1, memory_breakdown=None)

places = 2  # decimal places for the reported mean latency
convert = 1e3  # seconds -> milliseconds
float_size = 4  # bytes per fp32 element

if rank == 0:
    for i in range(iters):
        lat = time_list[i]
        print("latency = ", lat * convert)

minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
if rank == 0:
    print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))
#print("tensor shape", a.shape)

duration = meanlat / convert  # mean latency back in seconds
tput = tensor_size * float_size / duration  # uncompressed fp32 bytes moved per second
if rank == 0:
    print("algo throughput: %f Bytes/s, %f GB/s" % (tput, tput / 1e9))
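
# Bus bandwidth below follows the nccl-tests convention for allreduce:
# busbw = algbw * 2 * (n - 1) / n, i.e. the per-rank link traffic of a ring
# allreduce rather than the raw algorithm bandwidth.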
msg_size = tensor_size * float_size  # total message size in bytes
n = dist.get_world_size()
busbw = (msg_size / duration) * (2 * (n - 1) / n)
if rank == 0:
    print("busbw: %f GB/s" % (busbw / 1e9))
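
# Example launch (illustrative; flags depend on the cluster setup):
#   deepspeed --num_gpus=8 test_compressed_perf.py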