groups.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. '''
  2. Copyright 2021 The Microsoft DeepSpeed Team
  3. '''
  4. # The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file
  5. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. """
  19. Support different forms of parallelism in DeepSpeed using multiple process groups.
  20. Given that there are multiple scenarios and use-cases, this file is going to be updated
  21. frequently. For now, the group creation needed for the training scenario is being implemented.
  22. For inference and other new scenarios, the code will be either reused or added to this file.
  23. """
  24. from deepspeed import comm as dist
  25. from deepspeed.utils import log_dist
  26. from deepspeed.utils.exceptions import DeprecatedException
  27. # Expert parallel group that the current rank belongs to.
  28. _EXPERT_PARALLEL_GROUP = {}
  29. # Expert data parallel group that the current rank belongs to.
  30. _EXPERT_DATA_PARALLEL_GROUP = {}
  31. # dist world group needs to be cloned for some cases
  32. _WORLD_GROUP = None
  33. # global object to maintain mpu object if passed by a Megatron client
  34. mpu = None
  35. # global object that stores tensor parallel world size for experts
  36. expert_tensor_parallel_world_size = 1
  37. # Deprecated groups initialize function.
  38. def initialize(ep_size=1, mpu=None):
  39. """ Deprecated function. Retained to inform the users."""
  40. raise DeprecatedException(
  41. "Please do not use the groups.initialize() API as it is deprecated. Instead, pass the desired ep_size to deepspeed.moe.layer.MoE(..,ep_size,..)"
  42. )
  43. def _ensure_divisibility(numerator, denominator):
  44. """Ensure that numerator is divisible by the denominator."""
  45. assert numerator % denominator == 0, '{} is not divisible by {}'.format(
  46. numerator, denominator)
  47. # Not currently used. Helper function to create a model (tensor) parallel group.
  48. def _create_model_parallel(model_parallel_size_):
  49. """
  50. Initialize model data parallel groups.
  51. Arguments:
  52. model_parallel_size: number of GPUs used to parallelize model.
  53. Returns:
  54. Tuple of data parallel group and model parallel group
  55. Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
  56. use 2 GPUs to parallelize the model. The present function will
  57. create 4 model parallel groups and 2 data parallel groups as:
  58. 4 model parallel groups:
  59. [g0, g1], [g2, g3], [g4, g5], [g6, g7]
  60. 2 data parallel groups:
  61. [g0, g2, g4, g6], [g1, g3, g5, g7]
  62. Note that for efficiency, the caller should make sure adjacent ranks
  63. are on the same DGX box. For example if we are using 2 DGX-1 boxes
  64. with a total of 16 GPUs, rank 0 to 7 belong to the first box and
  65. ranks 8 to 15 belong to the second box.
  66. """
  67. log_dist(f'Creating model parallel group with size {model_parallel_size_}',
  68. ranks=[0])
  69. # Get world size and rank. Ensure some consistencies.
  70. assert dist.is_initialized()
  71. world_size = dist.get_world_size()
  72. model_parallel_size = min(model_parallel_size_, world_size)
  73. _ensure_divisibility(world_size, model_parallel_size)
  74. rank = dist.get_rank()
  75. _DATA_PARALLEL_GROUP = None
  76. _MODEL_PARALLEL_GROUP = None
  77. # Build the data parallel groups.
  78. for i in range(model_parallel_size):
  79. ranks = range(i, world_size, model_parallel_size)
  80. group = dist.new_group(ranks)
  81. if i == (rank % model_parallel_size):
  82. _DATA_PARALLEL_GROUP = group
  83. # Build the model parallel groups.
  84. for i in range(world_size // model_parallel_size):
  85. ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
  86. group = dist.new_group(ranks)
  87. if i == (rank // model_parallel_size):
  88. _MODEL_PARALLEL_GROUP = group
  89. return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
  90. def _create_expert_and_data_parallel(expert_parallel_size_):
  91. """
  92. Create expert and data parallel groups.
  93. Note: Caller of this function is responsible to check if the groups already exist.
  94. Example - E + D parallel
  95. world_size = 16
  96. expert_parallel_size = 2 # number of experts in same group
  97. expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
  98. expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all
  99. data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE
  100. """
  101. assert dist.is_initialized()
  102. log_dist(
  103. f'Creating expert and data parallel groups with size {expert_parallel_size_}',
  104. ranks=[0])
  105. world_size = dist.get_world_size()
  106. rank = dist.get_rank()
  107. _ensure_divisibility(world_size, expert_parallel_size_)
  108. group_name = f"ep_size_{expert_parallel_size_}"
  109. # Build the expert data parallel groups.
  110. global _EXPERT_DATA_PARALLEL_GROUP
  111. # Only create group if it does not already exist
  112. if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
  113. for i in range(expert_parallel_size_):
  114. ranks = range(i, world_size, expert_parallel_size_)
  115. group = dist.new_group(ranks)
  116. log_dist(
  117. f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
  118. [0])
  119. if i == (rank % expert_parallel_size_):
  120. _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
  121. # Build the expert parallel groups.
  122. global _EXPERT_PARALLEL_GROUP
  123. # Only create group if it does not already exist
  124. if group_name not in _EXPERT_PARALLEL_GROUP:
  125. for i in range(world_size // expert_parallel_size_):
  126. ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
  127. group = dist.new_group(ranks)
  128. log_dist(
  129. f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}',
  130. [0])
  131. if i == (rank // expert_parallel_size_):
  132. _EXPERT_PARALLEL_GROUP[group_name] = group
  133. def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel_size_):
  134. """Generate expert parallel and expert data parallel group ranks list.
  135. Example - E + M + D parallel
  136. world_size = 16
  137. model_degree = 2
  138. expert_degree = 4 # number of experts in same group
  139. mp_group = [0, 1], [2,3], [4,5] ...
  140. data_parallel_group =[0,2,4,6,8,10, 12,14], [1,3,5,7,9,11,13,15]
  141. expert_parallel_group = [0,2,4,6], [8,10,12,14] [1,3,5,7], [9,11,13,15]
  142. expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
  143. Args:
  144. world_size (int): Distributed world size.
  145. model_parallel_size_ (int): Model parallel group size.
  146. expert_parallel_size_ (int): Expert parallel group size.
  147. Returns:
  148. Expert parallel group ranks and Expert data parallel group ranks list.
  149. """
  150. _ensure_divisibility(world_size, model_parallel_size_)
  151. dp_world_size = world_size // model_parallel_size_
  152. _ensure_divisibility(dp_world_size, expert_parallel_size_)
  153. # Generate data parallel groups
  154. data_parallel_groups = []
  155. dp_group_size = model_parallel_size_
  156. for i in range(dp_group_size):
  157. data_parallel_groups.append(list(range(i, world_size, dp_group_size)))
  158. expert_parallel_groups = []
  159. expert_data_parallel_groups = []
  160. for dp_ranks in data_parallel_groups:
  161. # partition of expert parallel groups, e.g. [0,2,4,6], [8,10,12,14]
  162. part_ep_groups = []
  163. for i in range(0, dp_world_size, expert_parallel_size_):
  164. part_ep_groups.append(dp_ranks[i:i + expert_parallel_size_])
  165. expert_parallel_groups.extend(part_ep_groups)
  166. # zip part_ep_groups get expert data parallel ranks, e.g [0,8],[2,10],[4,12],[6,14]
  167. for expert_dp_ranks in zip(*part_ep_groups):
  168. expert_data_parallel_groups.append(list(expert_dp_ranks))
  169. return expert_parallel_groups, expert_data_parallel_groups
  170. def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu):
  171. """
  172. Create expert and data parallel groups based on MPU (model parallel) group.
  173. Note: Caller of this function is responsible to check if the groups already exist.
  174. Example - E + M + D parallel
  175. world_size = 16
  176. model_degree = 2
  177. expert_degree = 4 # number of experts in same group
  178. mp_group = [0, 1], [2,3], [4,5] ...
  179. data_parallel_group =[0,2,4,6,8,10, 12,14], [1,3,5,7,9,11,13,15]
  180. expert_parallel_group = [0,2,4,6], [8,10,12,14] [1,3,5,7], [9,11,13,15]
  181. expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
  182. """
  183. assert dist.is_initialized(), "dist is not initialized"
  184. model_parallel_size_ = mpu.get_model_parallel_world_size()
  185. global expert_tensor_parallel_world_size
  186. expert_tensor_parallel_world_size = model_parallel_size_
  187. world_size = dist.get_world_size()
  188. rank = dist.get_rank()
  189. dp_world_size = mpu.get_data_parallel_world_size()
  190. dp_rank = mpu.get_data_parallel_rank()
  191. _ensure_divisibility(world_size, model_parallel_size_)
  192. _ensure_divisibility(dp_world_size, expert_parallel_size_)
  193. log_dist(
  194. f"Creating deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}",
  195. [0])
  196. global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP
  197. # Get world size and rank. Ensure some consistencies.
  198. _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
  199. _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()
  200. group_name = f"ep_size_{expert_parallel_size_}"
  201. # Only create groups if they don't already exist
  202. # Need to check conditions outside the group creation loop because of the way torch.dist group creation works
  203. if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP:
  204. expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
  205. world_size, model_parallel_size_, expert_parallel_size_)
  206. for ranks in expert_parallel_groups:
  207. group = dist.new_group(ranks)
  208. if rank in list(ranks):
  209. _EXPERT_PARALLEL_GROUP[group_name] = group
  210. for ranks in expert_data_parallel_groups:
  211. group = dist.new_group(ranks)
  212. if rank in list(ranks):
  213. _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
  214. def _get_max_expert_size():
  215. """Get the maximum ep_size from all the created groups."""
  216. assert _EXPERT_PARALLEL_GROUP is not None, "Warning! Process group not initialized"
  217. keylist = []
  218. for key in _EXPERT_PARALLEL_GROUP.keys():
  219. # index 2 is ep_size in the group name: ep_size_<ep_size>
  220. index = 2
  221. keylist.append(int(key.split('_')[index]))
  222. return max(keylist) if len(keylist) > 0 else None
  223. def _get_max_expert_size_name():
  224. """Get the name of the group with max. ep_size"""
  225. return f'ep_size_{_get_max_expert_size()}'
  226. def _get_max_expert_parallel_group():
  227. """Get the max expert parallel size."""
  228. return _get_expert_parallel_group(_get_max_expert_size_name())
  229. def _get_expert_parallel_group(group_name):
  230. """Get the expert parallel group the caller rank belongs to."""
  231. assert group_name in _EXPERT_PARALLEL_GROUP, \
  232. 'expert parallel group is not initialized'
  233. return _EXPERT_PARALLEL_GROUP[group_name]
  234. def _get_expert_parallel_group_dict():
  235. """Get the expert parallel group dict."""
  236. return _EXPERT_PARALLEL_GROUP
  237. def _get_expert_data_parallel_group(group_name):
  238. """Get the expert data parallel group the caller rank belongs to."""
  239. assert group_name in _EXPERT_DATA_PARALLEL_GROUP, \
  240. 'expert data parallel group is not initialized'
  241. return _EXPERT_DATA_PARALLEL_GROUP[group_name]
  242. def _get_expert_data_parallel_group_dict():
  243. """Get the expert data parallel group dict."""
  244. return _EXPERT_DATA_PARALLEL_GROUP
  245. def _clone_world_group():
  246. """Create a clone of the world group
  247. Note: We need to clone the dist world group because we
  248. use dist.get_global_rank() utility function in DeepSpeed at many places.
  249. As that function does not work on dist.group.WORLD, we
  250. need to keep a clone of it.
  251. """
  252. assert dist.is_initialized(), "dist is not initialized"
  253. global _WORLD_GROUP
  254. if _WORLD_GROUP is None:
  255. # If not cloned already, clone the world group
  256. _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size()))
  257. return _WORLD_GROUP
  258. def _get_data_parallel_group():
  259. """Get the data parallel group the caller rank belongs to."""
  260. assert dist.is_initialized(), \
  261. 'dist is not initialized'
  262. global mpu
  263. if mpu is not None:
  264. return mpu.get_data_parallel_group()
  265. # Return the clone of dist world group
  266. return _clone_world_group()
  267. def _get_broadcast_src_rank():
  268. return dist.get_global_rank(_get_data_parallel_group(), 0)
  269. def _get_expert_broadcast_src_rank(group_name):
  270. return dist.get_global_rank(_get_expert_data_parallel_group(group_name), 0)
  271. def _get_expert_parallel_world_size(group_name):
  272. """Return world size for the expert parallel group."""
  273. return dist.get_world_size(group=_get_expert_parallel_group(group_name))
  274. def _get_expert_data_parallel_world_size(group_name):
  275. """Return world size for the expert data parallel group."""
  276. return dist.get_world_size(group=_get_expert_data_parallel_group(group_name))
  277. def _get_expert_parallel_rank(group_name):
  278. """Return my rank for the expert parallel group."""
  279. return dist.get_rank(group=_get_expert_parallel_group(group_name))
  280. def _get_expert_parallel_src_rank(group_name):
  281. """Calculate the global rank corresponding to a local rank zero
  282. in the expert parallel group."""
  283. global_rank = dist.get_rank()
  284. local_world_size = _get_expert_parallel_world_size(group_name)
  285. return (global_rank // local_world_size) * local_world_size
  286. def _get_expert_data_parallel_rank(group_name):
  287. """Return my rank for the expert data parallel group."""
  288. return dist.get_rank(group=_get_expert_data_parallel_group(group_name))
  289. def _get_data_parallel_world_size():
  290. """Return world size for the data parallel group."""
  291. global mpu
  292. if mpu is not None:
  293. return mpu.get_data_parallel_world_size()
  294. return dist.get_world_size(group=_get_data_parallel_group())
  295. def _get_model_parallel_world_size():
  296. """Return world size for the model parallel group."""
  297. global mpu
  298. if mpu is not None:
  299. return mpu.get_model_parallel_world_size()
  300. return 1
  301. def _get_data_parallel_rank():
  302. """Return my rank for the data parallel group."""
  303. global mpu
  304. if mpu is not None:
  305. return mpu.get_data_parallel_rank()
  306. return dist.get_rank(group=_get_data_parallel_group())
  307. def _get_expert_model_parallel_world_size():
  308. global expert_tensor_parallel_world_size
  309. return expert_tensor_parallel_world_size