utils.py

'''Copyright The Microsoft DeepSpeed Team'''

import os
from typing import List

import torch

from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adam import FusedAdam
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator


def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    parameter_parallel_size = parameter_parallel_size or data_parallel_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s",
                data_parallel_size,
                parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()

    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        # new_group is a collective call: every rank must create every group,
        # including the groups it does not belong to.
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
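

# Worked example of the grouping above (illustrative values): with a world
# size of 8 and parameter_parallel_size=2, the ranks are partitioned into the
# groups {0, 1}, {2, 3}, {4, 5}, {6, 7}, and each rank gets back the handle of
# the one group that contains it.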


class ZeRORuntimeException(Exception):
    pass


ZERO_SUPPORTED_OPTIMIZERS = [
    torch.optim.Adam,
    torch.optim.AdamW,
    FusedAdam,
    DeepSpeedCPUAdam,
]

# Add apex FusedAdam to supported list if apex is installed
try:
    import apex
    if hasattr(apex, 'optimizers') and hasattr(apex.optimizers, 'FusedAdam'):
        ZERO_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
except ImportError:
    pass


def is_zero_supported_optimizer(optimizer):
    if dist.get_rank() == 0:
        logger.info(
            f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
        )
    return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS


def get_lst_from_rank0(lst: List[int]) -> List[int]:
    """
    Broadcast rank 0's copy of ``lst`` and return it on every rank.

    NOTE: creates both communication and synchronization overhead so should be used
    sparingly
    """
    lst_tensor = torch.tensor(
        lst if dist.get_rank() == 0 else [-1] * len(lst),
        dtype=int,
        # device=get_accelerator().current_device_name(),
        device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])),
        requires_grad=False,
    )
    dist.broadcast(lst_tensor, src=0, async_op=False)
    return list(lst_tensor.cpu().numpy())


@instrument_w_nvtx
def assert_ints_same_as_other_ranks(ints: List[int]) -> None:
    """
    Takes a list of ints from each rank and ensures that they are the same
    across ranks, throwing an exception if they are not.

    NOTE: creates both communication and synchronization overhead so should be
    used sparingly
    """
    rank0_ints = get_lst_from_rank0(ints)
    if ints != rank0_ints:
        raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: "
                           f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}")