# mics_utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import os
from dataclasses import dataclass
from typing import List

import numpy as np
import torch
from torch import Tensor

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import logger


def _log_rank0(msg):
    if dist.get_rank() == 0:
        logger.info(msg)


@torch.jit.script
def scale_tensors(tensors: List[Tensor], scale: int):
    """Divide each tensor in ``tensors`` by ``scale`` in place."""
    for t in tensors:
        t.div_(scale)
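
# A minimal usage sketch (illustrative, not part of the original module):
# dividing a list of tensors by a replication factor in place.
#
#   ts = [torch.ones(2), torch.full((2,), 4.0)]
#   scale_tensors(ts, 2)
#   # ts == [tensor([0.5000, 0.5000]), tensor([2., 2.])]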


@dataclass
class MiCS_CommGroups:
    """Communication groups used by MiCS for parameter sharding and replication."""

    param_shard_group = None
    param_shard_size = -1
    param_shard_rank = -1

    param_repli_group = None
    param_repli_size = -1
    param_repli_rank = -1

    param_intra_node_group = None
    param_inter_node_shard_group = None


def create_mics_comm_groups(
    shard_size,
    dp_group,
    hierarchical_allgather=False,
    mpu=None,
):
    """
    Create the shard and replicate groups from the sharding configuration.
    TODO: consider broadcasting the config from rank 0.

    Returns:
        MiCS_CommGroups
    """
    # env var for debugging purposes
    ndevices_per_node = int(os.environ.get("NDEV_PER_NODE", get_accelerator().device_count()))
    _log_rank0(f'creating MiCS communication groups with per node device size {ndevices_per_node}')
    groups = MiCS_CommGroups()

    if mpu is not None:
        assert dp_group == mpu.get_data_parallel_group()

    # full size of the world
    world_size = dist.get_world_size()
    # global rank
    global_rank = dist.get_rank()

    config = _generate_mics_config(world_size, ndevices_per_node, shard_size, 1)
    ranks_of_shard_group = config['shard_groups']
    ranks_of_repli_group = config['replicate_groups']
    if len(ranks_of_repli_group) == 0:
        assert len(ranks_of_shard_group) == 1, "replicate groups are empty only for a single shard group"
        for r in ranks_of_shard_group[0]:
            ranks_of_repli_group.append([r])

    # for simplicity
    assert _sizes_all_same(ranks_of_repli_group), "replicate groups must have the same size"
    assert _sizes_all_same(ranks_of_shard_group), "shard groups must have the same size"

    assert sum([len(g) for g in ranks_of_shard_group]) == dist.get_world_size(), \
        "the shard groups must cover all ranks in the world"
    if len(ranks_of_shard_group) > 1:  # if we only shard in one group, there is no need for replicate groups
        assert len(ranks_of_shard_group) == len(
            ranks_of_repli_group[0]), "the number of shard groups must equal the size of each replicate group"

    global_rank = dist.get_rank()

    # create shard groups
    for shard_ranks in ranks_of_shard_group:
        _group = dist.new_group(shard_ranks)
        if global_rank in shard_ranks:
            groups.param_shard_group = _group
            groups.param_shard_size = len(shard_ranks)
            groups.param_shard_rank = dist.get_rank(_group)
            logger.info(f'rank {global_rank}, shard group'
                        f' {groups.param_shard_rank}/{dist.get_world_size(group=_group)}')

    # create replicate groups
    for repli_ranks in ranks_of_repli_group:
        if len(repli_ranks) > 1:
            _group = dist.new_group(repli_ranks)
            if global_rank in repli_ranks:
                groups.param_repli_group = _group
                groups.param_repli_size = len(repli_ranks)
                groups.param_repli_rank = dist.get_rank(group=_group)
                logger.info(f'rank {global_rank} '
                            f'replicate group {groups.param_repli_rank}/{dist.get_world_size(group=_group)}')
        else:
            groups.param_repli_group = None
            groups.param_repli_size = 1
            groups.param_repli_rank = 0
            logger.info(f'rank {global_rank} replicate group 0/1')

    # sanity check: the local shard group size matches the configured shard size
    assert groups.param_shard_size == len(ranks_of_shard_group[0])

    if hierarchical_allgather:
        # create hierarchical inter-node and intra-node groups
        # n_span_nodes = config['shard_span']
        n_span_nodes = config['span_nodes']
        assert n_span_nodes > 1, "sharding spans a single node, no need for hierarchical allgather"
        assert len(ranks_of_shard_group[0]) % n_span_nodes == 0

        n_gpu_per_node = len(ranks_of_shard_group[0]) // n_span_nodes
        intra_node_ranks_group = []
        inter_node_ranks_group = []
        for shard_group in ranks_of_shard_group:
            _intra_node_ranks = []
            for i in range(0, len(shard_group), n_gpu_per_node):
                _intra_node_ranks.append(shard_group[i:i + n_gpu_per_node])
            _inter_node_ranks = []
            for i in range(n_gpu_per_node):
                _ranks = [_g[i] for _g in _intra_node_ranks]
                _inter_node_ranks.append(_ranks)

            intra_node_ranks_group.append(_intra_node_ranks)
            inter_node_ranks_group.append(_inter_node_ranks)

        _log_rank0(f"creating hierarchical all-gather groups: intra-node {intra_node_ranks_group}")
        _log_rank0(f"creating hierarchical all-gather groups: inter-node {inter_node_ranks_group}")
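
        # Example (illustrative): a shard group [0..7] spanning 2 nodes with
        # n_gpu_per_node=4 yields intra-node groups [[0, 1, 2, 3], [4, 5, 6, 7]]
        # and inter-node groups [[0, 4], [1, 5], [2, 6], [3, 7]], i.e., one
        # group per local GPU index across nodes.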
        # create communicators
        for shard_group in intra_node_ranks_group:
            for intra_node_ranks in shard_group:
                _group = dist.new_group(intra_node_ranks)
                if global_rank in intra_node_ranks:
                    groups.param_intra_node_group = _group
                _log_rank0(f'create group for intra node ranks {intra_node_ranks}')

        for shard_group in inter_node_ranks_group:
            for inter_node_ranks in shard_group:
                _group = dist.new_group(inter_node_ranks)
                if global_rank in inter_node_ranks:
                    groups.param_inter_node_shard_group = _group
                _log_rank0(f'create group for inter node ranks {inter_node_ranks}')

    return groups
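

# A usage sketch (illustrative assumptions: an initialized deepspeed.comm
# environment with 8 ranks; `deepspeed.init_distributed()` is the standard
# entry point, but the shard size and layout below are made up):
#
#   import deepspeed
#   deepspeed.init_distributed()
#   groups = create_mics_comm_groups(
#       shard_size=4,
#       dp_group=None,  # or mpu.get_data_parallel_group() when an mpu is passed
#       hierarchical_allgather=True,
#   )
#   # all-gather sharded parameters within the shard group, e.g.
#   # dist.all_gather(..., group=groups.param_shard_group)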


def _generate_mics_config(world_size, ndev_per_node, shard_size, pp_size=1):
    """Generate the sharding configuration.

    This shard config generation assumes that the pipeline stages are
    partitioned in order, i.e., the first ranks hold stage 0, etc.

    Args:
        shard_size (int): ZeRO-3 data-parallel shard size. FIXME:
            change the name later.
        pp_size (int): pipeline parallel size; currently this only works with
            pipeline parallelism + ZeRO.
    """
    assert world_size % pp_size == 0
    assert (world_size // pp_size) % shard_size == 0, \
        f"dp group size is not divisible by dp_shard_size, " \
        f"(world_size {world_size}, pp_size {pp_size}, dp_shard_size {shard_size})"

    config = {}
    shard_groups = np.arange(world_size).reshape(-1, shard_size)
    replicate_groups = []
    for i in range(shard_size):
        same_shard_ranks = shard_groups[:, i].tolist()
        n_ranks = len(same_shard_ranks)
        replicate_size = n_ranks // pp_size
        replicate_groups.extend([same_shard_ranks[j:j + replicate_size] for j in range(0, n_ranks, replicate_size)])

    config['replicate_groups'] = replicate_groups
    config['shard_groups'] = shard_groups.tolist()
    config["span_nodes"] = len(shard_groups[0]) // ndev_per_node
    return config
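

# Worked example (illustrative): with world_size=8, ndev_per_node=2,
# shard_size=4 and pp_size=1, each shard group holds 4 consecutive ranks
# spanning 2 nodes, and replicate groups connect ranks that hold the same
# shard:
#
#   _generate_mics_config(8, 2, 4)
#   # {'replicate_groups': [[0, 4], [1, 5], [2, 6], [3, 7]],
#   #  'shard_groups': [[0, 1, 2, 3], [4, 5, 6, 7]],
#   #  'span_nodes': 2}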


def _sizes_all_same(groups):
    """Return True if all groups have the same length."""
    return all(len(g) == len(groups[0]) for g in groups)
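

# e.g. (illustrative): _sizes_all_same([[0, 1], [2, 3]]) is True,
# while _sizes_all_same([[0, 1], [2]]) is False.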