# test_nccl_backend.py
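# Sanity test for DeepSpeed's NcclBackend.compressed_allreduce (the 1-bit
# compression scheme used by DeepSpeed's 1-bit Adam). Its output is compared
# against torch_sim(), a reference simulation built from plain
# torch.distributed ops.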

import time
import torch
import torch.distributed as dist
import numpy as np
import argparse
import deepspeed
import os
from deepspeed.runtime.comm.nccl import NcclBackend

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend='nccl')
args.local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()
backend = NcclBackend()
local_rank = args.local_rank


# Simulate the compressed allreduce with plain torch.distributed ops; this is
# the reference the NCCL backend is checked against.
def torch_sim(a):
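    # Worker-side 1-bit compression: keep only the sign of each element and
    # rescale by the tensor's RMS norm (||a|| / sqrt(n)). The chained
    # sign().add_(1).bool().float().add_(-0.5).mul_(2.0) maps negative values
    # to -1 and zero/positive values to +1, so zeros count as positive.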
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
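    # Average the compressed tensors across all ranks.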
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
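    # Server-side recompression: take the sign of the averaged tensor and
    # rescale each rank's chunk by that chunk's own RMS norm.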
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat(
        [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
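    # Each rank keeps the compensation residual for its own chunk.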
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    torch.distributed.barrier()
    return a_server_compressed, worker_error, server_error


tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
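# Pad the tensor length up to a multiple of 8 * world_size: the backend packs
# the 1-bit signs into bytes and splits the tensor evenly across ranks
# (inferred from the padding arithmetic below).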
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Add a small rank-dependent bias to the simulated gradient so that no element
# is too close to zero for the sign-based compression to handle.
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
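
# Reference result from the pure torch.distributed simulation above.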
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
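
# The call under test: DeepSpeed's NCCL-based compressed allreduce. The
# worker_error/server_error buffers carry the error-feedback state and are
# presumably updated in place by the backend.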
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
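# Reconstruct the compensated server values for this rank's chunk
# (compressed result plus the local residual).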
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

test_correctness = True

# If a compensated server value is very small (e.g. 1e-8), calling sign() on it
# can differ between the two implementations, so positions whose magnitude is
# below magnitude_threshold are skipped.
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
    else:
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print(
                'Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} positions'.format(torch.sum(check_mag_mask)))
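
# Example launch (assuming a 2-GPU node; the deepspeed launcher sets LOCAL_RANK
# for each process):
#   deepspeed --num_gpus=2 test_nccl_backend.py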