utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from typing import List

import torch

from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
from deepspeed.ops.adam import FusedAdam
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator


def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    parameter_parallel_size = parameter_parallel_size or data_parallel_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s", data_parallel_size, parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
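

# Illustrative sketch (not part of the original module): with 8 data-parallel
# ranks and parameter_parallel_size=4, the loop above builds two process
# groups, ranks 0-3 and ranks 4-7, and each rank keeps only the group it
# belongs to. The pure-Python helper below (a hypothetical name) mirrors that
# rank layout without requiring torch.distributed to be initialized.
def _example_parameter_parallel_layout(data_parallel_size=8, parameter_parallel_size=4):
    assert data_parallel_size % parameter_parallel_size == 0
    # e.g. _example_parameter_parallel_layout() -> [[0, 1, 2, 3], [4, 5, 6, 7]]
    return [
        list(range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size))
        for i in range(data_parallel_size // parameter_parallel_size)
    ]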


class ZeRORuntimeException(Exception):
    pass


ZERO_SUPPORTED_OPTIMIZERS = [
    torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad
]

# Add apex FusedAdam to supported list if apex is installed
try:
    import apex
    if hasattr(apex, 'optimizers') and hasattr(apex.optimizers, 'FusedAdam'):
        ZERO_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
except ImportError:
    pass


def is_zero_supported_optimizer(optimizer):
    if dist.get_rank() == 0:
        logger.info(f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}')
    return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
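

# Illustrative sketch (not part of the original module): the check above is an
# exact type match, so a subclass of torch.optim.Adam would not be reported as
# supported. Assumes the comm backend is already initialized, since
# is_zero_supported_optimizer calls dist.get_rank(); the tiny model exists
# purely for demonstration.
def _example_check_optimizer_support():
    model = torch.nn.Linear(4, 4)
    adam = torch.optim.Adam(model.parameters(), lr=1e-3)
    return is_zero_supported_optimizer(adam)  # True: exact type match against the list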


def get_lst_from_rank0(lst: List[int]) -> List[int]:
    """
    Broadcast rank 0's copy of ``lst`` to all ranks and return it as a list.

    NOTE: creates both communication and synchronization overhead so should be used
    sparingly
    """
    lst_tensor = torch.tensor(
        lst if dist.get_rank() == 0 else [-1] * len(lst),
        dtype=int,
        # device=get_accelerator().current_device_name(),
        device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])),
        requires_grad=False,
    )
    dist.broadcast(lst_tensor, src=0, async_op=False)

    return list(lst_tensor.cpu().numpy())
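

# Illustrative sketch (not part of the original module; assumes an initialized
# distributed run with LOCAL_RANK set): non-zero ranks only contribute the
# length of their list, a [-1, ...] placeholder is allocated, and the
# broadcast above overwrites it with rank 0's values on every rank.
def _example_get_lst_from_rank0():
    local = [1, 2, 3] if dist.get_rank() == 0 else [0, 0, 0]
    return get_lst_from_rank0(local)  # every rank receives [1, 2, 3]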


@instrument_w_nvtx
def assert_ints_same_as_other_ranks(ints: List[int]) -> None:
    """
    NOTE: creates both communication and synchronization overhead so should be
    used sparingly

    Takes a list of ints from each rank and ensures that they are the same
    across ranks, throwing an exception if they are not.
    """
    rank0_ints = get_lst_from_rank0(ints)
    if ints != rank0_ints:
        raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: "
                           f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}")