utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from typing import List

import torch

from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator


def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    parameter_parallel_size = parameter_parallel_size or data_parallel_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s", data_parallel_size, parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
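

# Usage sketch (illustrative, not part of the original module): on an 8-rank job that has
# already initialized the communication backend, parameter_parallel_size=2 yields 4 process
# groups of 2 ranks each; every rank receives the group containing its own rank (ranks 0-1
# share one group, ranks 2-3 the next, and so on).
def _example_parameter_parallel_groups():
    group = _initialize_parameter_parallel_groups(parameter_parallel_size=2)
    return group  # pass this group to collectives that should stay within the sub-group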


class ZeRORuntimeException(Exception):
    pass


ZERO_SUPPORTED_OPTIMIZERS = [
    torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad,
    DeepSpeedCPULion, FusedLion
]

# Add apex FusedAdam to supported list if apex is installed
try:
    import apex
    if hasattr(apex, 'optimizers') and hasattr(apex.optimizers, 'FusedAdam'):
        ZERO_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
except ImportError:
    pass


def is_zero_supported_optimizer(optimizer):
    if dist.get_rank() == 0:
        logger.info(f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}')
    return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
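

# Usage sketch (illustrative, not part of the original module): a plain torch.optim.Adam
# instance is reported as supported. Note that the check uses exact type matching, so a
# subclass of a supported optimizer is not recognized.
def _example_optimizer_check(model_parameters):
    optimizer = torch.optim.Adam(model_parameters, lr=1e-3)
    return is_zero_supported_optimizer(optimizer)  # True; only rank 0 logs the check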


def get_lst_from_rank0(lst: List[int]) -> List[int]:
    """
    NOTE: creates both communication and synchronization overhead so should be used
    sparingly
    """
    lst_tensor = torch.tensor(
        lst if dist.get_rank() == 0 else [-1] * len(lst),
        dtype=int,
        # device=get_accelerator().current_device_name(),
        device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])),
        requires_grad=False,
    )
    dist.broadcast(lst_tensor, src=0, async_op=False)
    return list(lst_tensor.cpu().numpy())
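

# Usage sketch (illustrative, not part of the original module): share a small list of ints
# that only rank 0 has computed. All ranks must pass a list of the same length; the values
# on non-zero ranks are ignored. Assumes LOCAL_RANK is set, as under the deepspeed launcher.
def _example_broadcast_from_rank0():
    local_lst = [128, 64] if dist.get_rank() == 0 else [0, 0]  # placeholder on other ranks
    return get_lst_from_rank0(local_lst)  # every rank now holds [128, 64]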


@instrument_w_nvtx
def assert_ints_same_as_other_ranks(ints: List[int]) -> None:
    """
    NOTE: creates both communication and synchronization overhead so should be
    used sparingly

    takes a list of ints from each rank and ensures that they are the same
    across ranks, throwing an exception if they are not.
    """
    rank0_ints = get_lst_from_rank0(ints)
    if ints != rank0_ints:
        raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: "
                           f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}")


def is_builtin_type(obj):
    # https://stackoverflow.com/a/17795199
    return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"


def isinstance_namedtuple(obj: object) -> bool:
    """
    Is this an instance of namedtuple/NamedTuple?
    From: https://stackoverflow.com/a/62692640

    Args:
        obj (object): An object.

    Returns:
        bool: True if namedtuple/NamedTuple else False.
    """
    return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields')
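

# Usage sketch (illustrative, not part of the original module): namedtuples are detected and
# plain tuples are not, which matters in apply_to_tensors_only below because a namedtuple has
# to be rebuilt from positional fields rather than from a single iterable.
def _example_namedtuple_check():
    from collections import namedtuple
    Point = namedtuple('Point', ['x', 'y'])
    assert isinstance_namedtuple(Point(1, 2))
    assert not isinstance_namedtuple((1, 2))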


def is_zero_param(parameter):
    if not torch.is_tensor(parameter):
        return False
    return hasattr(parameter, 'ds_id')


# Module-level flag so the warning below is only emitted once per process.
warned = False


def apply_to_tensors_only(function, value, warning_msg_fn=None):
    """
    Apply `function` to every Tensor in `value`.

    Args:
        function (callable): The function to apply to each tensor.
        value (Any): Target object to apply `function` to.
        warning_msg_fn (callable, optional): Builds a warning message for values that are
            neither tensors nor builtin types; the warning is logged once, on rank 0 only.

    Returns:
        Any: Output of `function`.
    """
    if isinstance(value, (tuple, list)):
        touched_outputs = []
        for elem in value:
            touched_output = apply_to_tensors_only(function, elem)
            touched_outputs.append(touched_output)

        if isinstance_namedtuple(value):
            # namedtuples require a slightly different syntax.
            return value.__class__(*touched_outputs)

        return value.__class__(touched_outputs)
    elif isinstance(value, dict):
        # apply inplace to avoid recreating dict inherited objects
        for key in value.keys():
            value[key] = apply_to_tensors_only(function, value[key])
        return value
    elif isinstance(value, torch.Tensor):
        # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter
        touched_output = function(value)

        # restore zero param attributes if those get stripped by `function`
        if not is_zero_param(touched_output) and is_zero_param(value):
            touched_output.ds_param_alias = value

        return touched_output
    else:
        if not is_builtin_type(value):
            global warned
            if warning_msg_fn and not warned and dist.get_rank() == 0:
                logger.warning(warning_msg_fn(value))
                warned = True
        return value
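

# Usage sketch (illustrative, not part of the original module): apply a function to every
# tensor inside a nested structure of dicts, lists, and namedtuples; non-tensor leaves such
# as strings and ints pass through unchanged, and dicts are updated in place.
def _example_apply_to_tensors_only():
    from collections import namedtuple
    Pair = namedtuple('Pair', ['a', 'b'])
    nested = {'inputs': [torch.ones(2), Pair(torch.zeros(3), 'label')], 'step': 7}
    doubled = apply_to_tensors_only(lambda t: t * 2, nested)
    # doubled['inputs'][0] is now a tensor of twos; doubled['inputs'][1].b is still 'label'.
    return doubled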