# test_compressed_backend.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os

from deepspeed.runtime.comm.compressed import CompressedBackend
from deepspeed.accelerator import get_accelerator

# Launcher compatibility: accept --local_rank even though the authoritative
# value is taken from the LOCAL_RANK environment variable below.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

# Initialize the distributed process group with the accelerator's native
# communication backend before touching any device state.
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
# Override the CLI value with the launcher-provided environment variable.
args.local_rank = int(os.environ['LOCAL_RANK'])

# Bind this process to its device, then build the torch.device handle for it.
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

# Backend under test: DeepSpeed's compressed (1-bit style) all-reduce.
backend = CompressedBackend()
local_rank = args.local_rank
# A simulated compression function using deepspeed.comm
def torch_sim(a):
    """Reference implementation of the compressed all-reduce.

    Simulates, with plain torch + deepspeed.comm collectives, what the
    CompressedBackend is expected to compute for input tensor ``a``:
    sign/scale compression with error feedback on both the worker and
    the server side.

    Returns:
        (a_server_compressed, worker_error, server_error) — the
        compressed all-reduce result and the two residual tensors.
    """
    # Map each element to +/-1 (zeros become +1): sign() -> {-1,0,1},
    # add_(1) -> {0,1,2}, bool() -> {0,1}, then rescale to {-1,+1}.
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    # Single scalar scale so the compressed tensor preserves the L2 norm.
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    # Worker-side error feedback: what the compression step lost.
    worker_error = a - a_compressed
    # Average the compressed tensors across all ranks.
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    # Server-side compression: same sign trick applied to the averaged tensor.
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    # Each rank owns one chunk ("server shard") with its own scale.
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat([server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    # Server-side error feedback for this rank's chunk only.
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    # Make sure all device work and all ranks are done before returning.
    get_accelerator().synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error
  42. tensor_size = 300 * 2**20
  43. server_size = int(tensor_size / size)
  44. if tensor_size % (8 * size) != 0:
  45. right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
  46. else:
  47. right_tensor_size = tensor_size
  48. right_server_size = right_tensor_size // size
  49. # Adding bias to the initialization of the gradient we are communicating
  50. # In order to get rid of the case where some elements in the gradient are too small
  51. a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
  52. worker_error = torch.zeros(right_tensor_size, device=device)
  53. server_error = torch.zeros(right_server_size, device=device)
  54. a_torch, worker_error_torch, server_error_torch = torch_sim(a)
  55. get_accelerator().empty_cache()
  56. a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
  57. print(a_torch.cpu())
  58. print(a_after.cpu())
  59. threshold = 1e-6
  60. magnitude_threshold = 1e-6
  61. diff_mask = (a_after - a_torch) > threshold
  62. diff_server_mask = torch.chunk(diff_mask, size)[rank]
  63. mpi_server = torch.chunk(a_after, size)[rank] + server_error
  64. torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
  65. test_correctness = True
  66. # If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
  67. # The test would skip those numbers that are too small in compensated_server_m
  68. if test_correctness:
  69. if torch.sum(diff_server_mask) == 0:
  70. print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
  71. else:
  72. check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
  73. if torch.sum(check_mag_mask) == 0:
  74. print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
  75. else:
  76. print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))