test_server_error.py

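"""Compare OnebitAdam's Compressed_Allreduce with the pure-PyTorch simulation
in torch_sim below, and report the pre-allgather (server-side) difference on
rank 0. Intended to be launched with one process per GPU under an MPI
launcher (e.g. mpirun).
"""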
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# NOTE: the master address/port is hard-coded here; adjust 'worker-0:2245' to your setup.
torch.distributed.init_process_group(backend='nccl',
                                     init_method='tcp://worker-0:2245',
                                     world_size=size,
                                     rank=rank)

# A dummy model/optimizer pair, used only to get access to Compressed_Allreduce
dummy_model = [torch.nn.Parameter(torch.ones(10))]
dummy_optim = OnebitAdam(dummy_model, cuda_aware=False)

device = torch.device('cuda', rank % torch.cuda.device_count())
def torch_sim(a):
    # Pure-PyTorch reference simulation of the compressed allreduce:
    # 1-bit (sign) compression with a single scale per tensor, plus the
    # worker- and server-side error residuals.
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat(
        [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    torch.distributed.barrier()
    return a_server_compressed, worker_error, server_error
# Input tensor size
tensor_size = 100 * 2**20
server_size = int(tensor_size / size)
# Pad the size up to a multiple of 8 * size so the tensor splits evenly into
# `size` chunks of whole bytes (the 1-bit compression packs signs into bytes).
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# The -0.5 is required for avoiding sign flips/errors
a = torch.rand(tensor_size, device=device) - 0.5

worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

# Reference result from the pure-PyTorch simulation
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()

local_rank = rank % torch.cuda.device_count()
# Test the 1-bit Adam optimizer's compressed allreduce against the simulation
a_after = dummy_optim.Compressed_Allreduce(a,
                                           worker_error,
                                           server_error,
                                           rank,
                                           size,
                                           comm,
                                           local_rank)

# If the error is below the threshold, it is acceptable for training
threshold = 1e-6
diff_pos = ((a_after - a_torch) > threshold)

if rank == 0:
    before_diff = torch.chunk(a_after - a_torch,
                              size)[rank] + server_error - server_error_torch
    if torch.norm(before_diff) / torch.norm(torch.chunk(a_after, size)[rank]) < threshold:
        print('Successfully passed the test')
    else:
        print('The difference for the tensor before allgather is {}'.format(
            torch.norm(before_diff)))
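
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the test): the per-worker step that
# torch_sim simulates above is one round of 1-bit (sign) compression with
# error feedback, i.e. compress the residual-corrected tensor to a single
# scale times its signs and carry the leftover into the next step. The helper
# below is a minimal single-process version of that math; the name
# `_compress_with_error_feedback` is made up for illustration and is not part
# of DeepSpeed's API.
def _compress_with_error_feedback(grad, error):
    corrected = grad + error                            # fold in the residual from the previous step
    sign = corrected.sign()
    sign[sign == 0] = 1.0                               # zeros become +1, matching sign().add_(1).bool() above
    scale = corrected.norm() / np.sqrt(corrected.numel())
    compressed = scale * sign                           # one scale per tensor, one bit per element
    new_error = corrected - compressed                  # residual carried into the next iteration
    return compressed, new_error

# Example usage (single process, CPU), kept commented out so the test itself
# is unchanged:
#   g = torch.rand(1024) - 0.5
#   err = torch.zeros_like(g)
#   g_compressed, err = _compress_with_error_feedback(g, err)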