test_com_reduce_host.py

from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
                                     init_method='tcp://worker-1:2245',
                                     world_size=size,
                                     rank=rank)

dummy_model = [torch.nn.Parameter(torch.ones(10))]
# Set cuda_aware to False to use host buffers for communication
dummy_optim = OnebitAdam(dummy_model, cuda_aware=False)

device = torch.device('cuda', rank % torch.cuda.device_count())
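
# torch_sim(a) below is a pure-PyTorch reference for the compressed allreduce
# used by 1-bit Adam: each worker compresses its tensor to sign bits plus a
# single scale (norm / sqrt(numel)) and keeps the compression residual
# (worker_error); the compressed tensors are averaged with a regular
# all_reduce; then each rank's chunk of the average is re-compressed the same
# way, keeping that residual (server_error). Compressed_Allreduce is validated
# against this reference further down.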
def torch_sim(a):
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat(
        [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    torch.distributed.barrier()
    return a_server_compressed, worker_error, server_error

tensor_size = 100 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
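# The error buffers are padded to a multiple of 8 * size, presumably so that
# each rank's chunk packs its sign bits into whole bytes. Illustrative example
# (numbers not used below): with tensor_size = 1000 and size = 3,
# 1000 % 24 == 16, so the padded size would be 1000 + (24 - 16) = 1008. The
# 100 * 2**20 tensor used here is already a multiple of 8 * size for
# power-of-two world sizes, so no padding is added in that case.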

# Add a rank-dependent bias to the gradient being communicated so that no
# element ends up too close to zero (small values make sign() unstable).
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()

a_after = dummy_optim.Compressed_Allreduce(a,
                                           worker_error,
                                           server_error,
                                           rank,
                                           size,
                                           comm,
                                           local_rank)
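
# a_after holds the result of the MPI/host-buffer compressed allreduce; the
# error buffers passed in are the error-feedback state (the implementation is
# expected to update them in place). Below, a_after is compared elementwise
# against the torch_sim reference, tolerating tiny numerical differences.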
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

# If a value in compensated_server_m is too small (e.g. 1e-8), calling sign()
# on it can be problematic, so the test skips positions whose compensated
# magnitude is below magnitude_threshold.
if torch.sum(diff_server_mask) == 0:
    print('Successfully passed the test for 1-bit Adam at Rank {}'.format(rank))
else:
    check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
    if torch.sum(check_mag_mask) == 0:
        print('Successfully passed the test for 1-bit Adam at Rank {}'.format(rank))
    else:
        print('Fails at {} positions'.format(torch.sum(check_mag_mask)))
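
# Example launch (an assumption, not part of the original script): run one MPI
# process per GPU, and make sure 'worker-1' in the init_method above resolves
# to the rank-0 host on your cluster, e.g.:
#   mpirun -np 4 python test_com_reduce_host.py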