# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Support different forms of parallelism in DeepSpeed using multiple process groups.
Given that there are multiple scenarios and use-cases, this file is going to be updated
frequently. For now, the group creation needed for the training scenario is being implemented.
For inference and other new scenarios, the code will be either reused or added to this file.
"""

from deepspeed import comm as dist
from deepspeed.utils import log_dist
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size
from deepspeed.utils.exceptions import DeprecatedException
from deepspeed.accelerator import get_accelerator

# Expert parallel group that the current rank belongs to.
_EXPERT_PARALLEL_GROUP = {}
# Expert data parallel group that the current rank belongs to.
_EXPERT_DATA_PARALLEL_GROUP = {}
# dist world group needs to be cloned for some cases
_WORLD_GROUP = None
# ZeRO parameter partitioning group that the current rank belongs to.
_ZERO_PARAM_INTRA_PARALLEL_GROUP = None
# global object to maintain mpu object if passed by a Megatron client
mpu = None
# global object that stores tensor parallel world size for experts
expert_tensor_parallel_world_size = 1
# All-to-all quantized gradient communication groups
_ALL_TO_ALL_GROUP = {}

_DATA_PARALLEL_GROUP = None


# Deprecated groups initialize function.
def initialize(ep_size=1, mpu=None):
    """ Deprecated function. Retained to inform the users."""
    raise DeprecatedException(
        "Please do not use the groups.initialize() API as it is deprecated. Instead, pass the desired ep_size to deepspeed.moe.layer.MoE(..,ep_size,..)"
    )
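
# Usage note (a sketch based on the exception message above, not part of the
# public API): instead of groups.initialize(), the expert parallel size is
# passed directly to the MoE layer, e.g.
#   deepspeed.moe.layer.MoE(..., ep_size=2, ...)
# where the remaining MoE arguments are elided here.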


def _ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)


# Not currently used. Helper function to create a model (tensor) parallel group.
def _create_model_parallel(model_parallel_size_):
    """
    Initialize model data parallel groups.

    Arguments:
        model_parallel_size_: number of GPUs used to parallelize model.

    Returns:
        Tuple of data parallel group and model parallel group

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    log_dist(f'Creating model parallel group with size {model_parallel_size_}', ranks=[0])
    # Get world size and rank. Ensure some consistencies.
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    model_parallel_size = min(model_parallel_size_, world_size)
    _ensure_divisibility(world_size, model_parallel_size)
    rank = dist.get_rank()

    _DATA_PARALLEL_GROUP = None
    _MODEL_PARALLEL_GROUP = None
    # Build the data parallel groups.
    for i in range(model_parallel_size):
        ranks = range(i, world_size, model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank % model_parallel_size):
            _DATA_PARALLEL_GROUP = group
    # Build the model parallel groups.
    for i in range(world_size // model_parallel_size):
        ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank // model_parallel_size):
            _MODEL_PARALLEL_GROUP = group

    return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
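
# Illustrative walk-through (comments only, not executed): with world_size = 8
# and model_parallel_size_ = 2, the loops above build
#   data parallel groups:  [0, 2, 4, 6] and [1, 3, 5, 7]
#   model parallel groups: [0, 1], [2, 3], [4, 5], [6, 7]
# so e.g. rank 3 returns the data parallel group over [1, 3, 5, 7] and the
# model parallel group over [2, 3].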


def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expert_parallel_=False):
    """
    Create expert and data parallel groups.

    Note: Caller of this function is responsible to check if the groups already exist.

    Example - E + D parallel
        world_size = 16
        expert_parallel_size = 2 # number of experts in same group
        expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
        expert_parallel_group = [0,1], [2,3], [4,5], ..., [14,15] - no all reduce, but all to all
        data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE

    use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology
    """
    assert dist.is_initialized()

    log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0])
    world_size = dist.get_world_size()
    pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
    rank = dist.get_rank()
    pp_stride = world_size // pp_world_size
    _ensure_divisibility(pp_stride, expert_parallel_size_)

    group_name = f"ep_size_{expert_parallel_size_}"

    # Build the expert data parallel groups.
    global _EXPERT_DATA_PARALLEL_GROUP

    ep_stride = pp_stride // expert_parallel_size_

    # Only create group if it does not already exist
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
        for pp_stage_start in range(0, world_size, pp_stride):
            for i in range(expert_parallel_size_):
                if use_data_before_expert_parallel_:
                    ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride)
                else:
                    ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_)
                group = dist.new_group(ranks)
                log_dist(
                    f'Creating expert data parallel process group named {group_name} '
                    f'with ranks: {list(ranks)}', [0])
                if rank in ranks:
                    _EXPERT_DATA_PARALLEL_GROUP[group_name] = group

    # Build the expert parallel groups.
    global _EXPERT_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_PARALLEL_GROUP:
        if use_data_before_expert_parallel_:
            for pp_stage_start in range(0, world_size, pp_stride):
                for i in range(ep_stride):
                    ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride)
                    group = dist.new_group(ranks)
                    log_dist(
                        f'creating expert parallel process group named {group_name} '
                        f'with ranks: {list(ranks)}', [0])
                    if rank in ranks:
                        _EXPERT_PARALLEL_GROUP[group_name] = group
        else:
            for i in range(world_size // expert_parallel_size_):
                ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
                group = dist.new_group(ranks)
                log_dist(f'creating expert parallel process group named {group_name} '
                         f'with ranks: {list(ranks)}', [0])
                if rank in ranks:
                    _EXPERT_PARALLEL_GROUP[group_name] = group
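
# Illustrative walk-through (comments only, not executed): with world_size = 16,
# no pipeline parallelism, and expert_parallel_size_ = 2 in the default E + D
# topology, the loops above create
#   expert parallel groups:      [0,1], [2,3], ..., [14,15]   (all-to-all for MoE)
#   expert data parallel groups: [0,2,...,14], [1,3,...,15]   (all-reduce on MoE params)
# and each rank keeps only the group it belongs to under the key "ep_size_2".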


def _get_expert_parallel_ranks(world_size,
                               tensor_parallel_size_,
                               expert_parallel_size_,
                               pipeline_parallel_size_=1,
                               use_data_before_expert_parallel_=False):
    """Generate expert parallel and expert data parallel group ranks list.

    Example - E + M + D parallel
        world_size = 16
        model_degree = 2
        expert_degree = 4 # number of experts in same group
        mp_group = [0, 1], [2,3], [4,5] ...
        data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15]
        expert_parallel_group = [0,2,4,6], [8,10,12,14], [1,3,5,7], [9,11,13,15]
        expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]

    Args:
        world_size (int): Distributed world size.
        tensor_parallel_size_ (int): Tensor parallel group size.
        expert_parallel_size_ (int): Expert parallel group size.
        pipeline_parallel_size_ (int): Pipeline parallel group size.
        use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology

    Returns:
        Expert parallel group ranks and Expert data parallel group ranks list.
    """
    _ensure_divisibility(world_size, tensor_parallel_size_ * pipeline_parallel_size_)
    dp_world_size = world_size // (tensor_parallel_size_ * pipeline_parallel_size_)
    _ensure_divisibility(dp_world_size, expert_parallel_size_)

    # Generate data parallel groups
    data_parallel_groups = []
    dp_group_size = tensor_parallel_size_
    pp_stride = world_size // pipeline_parallel_size_

    if use_data_before_expert_parallel_:
        dp_stride = world_size // expert_parallel_size_ // tensor_parallel_size_ // pipeline_parallel_size_
        for pp_stage_start in range(0, world_size, pp_stride):
            pp_stage_next = pp_stage_start + pp_stride
            for i in range(dp_group_size):
                data_parallel_groups.append(list())
                for ds in range(dp_stride):
                    # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30]
                    # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31]
                    data_parallel_groups[-1].extend(
                        list(
                            range(pp_stage_start + i + ds * tensor_parallel_size_, pp_stage_next,
                                  dp_stride * tensor_parallel_size_)))
    else:
        for pp_stage_start in range(0, world_size, pp_stride):
            pp_stage_next = pp_stage_start + pp_stride
            for i in range(dp_group_size):
                data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size)))

    expert_parallel_groups = []
    expert_data_parallel_groups = []
    for dp_ranks in data_parallel_groups:
        # partition of expert parallel groups, e.g. [0,2,4,6], [8,10,12,14]
        part_ep_groups = []
        for i in range(0, dp_world_size, expert_parallel_size_):
            part_ep_groups.append(dp_ranks[i:i + expert_parallel_size_])
        expert_parallel_groups.extend(part_ep_groups)

        # zip part_ep_groups to get expert data parallel ranks, e.g. [0,8],[2,10],[4,12],[6,14]
        for expert_dp_ranks in zip(*part_ep_groups):
            expert_data_parallel_groups.append(list(expert_dp_ranks))

    return expert_parallel_groups, expert_data_parallel_groups
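
# Worked example (comments only; the function above is pure, so it can be
# checked without initializing torch.distributed):
#   ep_groups, ep_dp_groups = _get_expert_parallel_ranks(
#       world_size=16, tensor_parallel_size_=2, expert_parallel_size_=4)
#   # ep_groups    -> [[0,2,4,6], [8,10,12,14], [1,3,5,7], [9,11,13,15]]
#   # ep_dp_groups -> [[0,8], [2,10], [4,12], [6,14], [1,9], [3,11], [5,13], [7,15]]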


def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_before_expert_parallel_=False):
    """
    Create expert and data parallel groups based on MPU (model parallel) group.

    Note: Caller of this function is responsible to check if the groups already exist.

    Example - E + M + D parallel
        world_size = 16
        model_degree = 2
        expert_degree = 4 # number of experts in same group
        mp_group = [0, 1], [2,3], [4,5] ...
        data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15]
        expert_parallel_group = [0,2,4,6], [8,10,12,14], [1,3,5,7], [9,11,13,15]
        expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
    """
    assert dist.is_initialized(), "dist is not initialized"
    tensor_parallel_size_ = bwc_tensor_model_parallel_world_size(mpu)

    global expert_tensor_parallel_world_size
    expert_tensor_parallel_world_size = tensor_parallel_size_

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dp_world_size = mpu.get_data_parallel_world_size()
    pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)

    _ensure_divisibility(world_size, tensor_parallel_size_)
    _ensure_divisibility(dp_world_size, expert_parallel_size_)

    log_dist(
        f"Creating deepspeed groups with model parallel size {tensor_parallel_size_}, "
        f"pipeline parallel size {pp_world_size}, expert parallel size {expert_parallel_size_}, "
        f"world size {world_size}, dp world size {dp_world_size}", [0])

    global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP

    group_name = f"ep_size_{expert_parallel_size_}"

    # Only create groups if they don't already exist
    # Need to check conditions outside the group creation loop because of the way torch.dist group creation works
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP:
        expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
            world_size, tensor_parallel_size_, expert_parallel_size_, pp_world_size, use_data_before_expert_parallel_)
        for ranks in expert_parallel_groups:
            group = dist.new_group(ranks)
            if rank in list(ranks):
                _EXPERT_PARALLEL_GROUP[group_name] = group
        for ranks in expert_data_parallel_groups:
            group = dist.new_group(ranks)
            if rank in list(ranks):
                _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
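
# Usage note (a sketch, comments only): dist.new_group() is a collective call,
# so every rank iterates over all group rank lists and stores only the group
# that contains its own rank. With the 16-GPU example in the docstring above,
# rank 10 would keep, under "ep_size_4", the expert parallel group [8,10,12,14]
# and the expert data parallel group [2,10].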


def _get_max_expert_size():
    """Get the maximum ep_size from all the created groups."""
    assert _EXPERT_PARALLEL_GROUP is not None, "Warning! Process group not initialized"
    keylist = []
    for key in _EXPERT_PARALLEL_GROUP.keys():
        # index 2 is ep_size in the group name: ep_size_<ep_size>
        index = 2
        keylist.append(int(key.split('_')[index]))
    return max(keylist) if len(keylist) > 0 else None


def _get_max_expert_size_name():
    """Get the name of the group with the maximum ep_size."""
    return f'ep_size_{_get_max_expert_size()}'


def _get_max_expert_parallel_group():
    """Get the expert parallel group with the maximum ep_size."""
    return _get_expert_parallel_group(_get_max_expert_size_name())
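
# Illustrative note (comments only): group names follow the pattern
# "ep_size_<ep_size>", so "ep_size_4".split('_')[2] -> "4". With groups named
# "ep_size_2" and "ep_size_4" registered, _get_max_expert_size() returns 4 and
# _get_max_expert_size_name() returns "ep_size_4".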


def _get_expert_parallel_group(group_name):
    """Get the expert parallel group the caller rank belongs to."""
    assert group_name in _EXPERT_PARALLEL_GROUP, \
        'expert parallel group is not initialized'
    return _EXPERT_PARALLEL_GROUP[group_name]


def _get_expert_parallel_group_dict():
    """Get the expert parallel group dict."""
    return _EXPERT_PARALLEL_GROUP


def _get_expert_data_parallel_group(group_name):
    """Get the expert data parallel group the caller rank belongs to."""
    assert group_name in _EXPERT_DATA_PARALLEL_GROUP, \
        'expert data parallel group is not initialized'
    return _EXPERT_DATA_PARALLEL_GROUP[group_name]


def _get_expert_data_parallel_group_dict():
    """Get the expert data parallel group dict."""
    return _EXPERT_DATA_PARALLEL_GROUP


def _clone_world_group():
    """Create a clone of the world group.

    Note: We need to clone the dist world group because we
    use dist.get_global_rank() utility function in DeepSpeed at many places.
    As that function does not work on dist.group.WORLD, we
    need to keep a clone of it.
    """
    assert dist.is_initialized(), "dist is not initialized"
    global _WORLD_GROUP
    if _WORLD_GROUP is None:
        # If not cloned already, clone the world group
        _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size()))
    return _WORLD_GROUP


# Build intra-node ("local_<node index>") and inter-node ("global_<local device index>")
# process groups used for all-to-all communication within and across nodes.
def _get_local_all_to_all_group():
    assert dist.is_initialized(), 'dist is not initialized'
    global _ALL_TO_ALL_GROUP
    device_per_node = get_accelerator().device_count()
    num_local = dist.get_world_size() // device_per_node
    if num_local == 0 and dist.get_world_size() > 0:
        assert dist.get_world_size() >= 1, 'num_gpus must >=1, cannot initialize All-To-All'
        cur_rank = []
        for i in range(dist.get_world_size()):
            cur_rank.append(i)
        _ALL_TO_ALL_GROUP['local_0'] = dist.new_group(ranks=cur_rank)
    elif num_local == 1:
        assert dist.get_world_size(
        ) == device_per_node, 'num_gpus not equal to device per node, cannot initialize All-To-All'
        _ALL_TO_ALL_GROUP['local_0'] = dist.new_group(ranks=[i for i in range(device_per_node)])
    else:
        assert dist.get_world_size() > device_per_node, 'num_nodes<2 cannot initialize All-To-All'
        for i in range(num_local):
            local_rank = [j + device_per_node * i for j in range(device_per_node)]
            _ALL_TO_ALL_GROUP[f"local_{i}"] = dist.new_group(ranks=local_rank)

        for i in range(device_per_node):
            cur_rank = []
            for j in range(num_local):
                cur_rank.append(i + j * device_per_node)
            _ALL_TO_ALL_GROUP[f"global_{i}"] = dist.new_group(ranks=cur_rank)
    return _ALL_TO_ALL_GROUP
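
# Illustrative layout (comments only): with 2 nodes of 4 devices each
# (world_size = 8, device_per_node = 4), the dictionary above holds
#   local_0  = [0, 1, 2, 3]     local_1  = [4, 5, 6, 7]
#   global_0 = [0, 4]   global_1 = [1, 5]   global_2 = [2, 6]   global_3 = [3, 7]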


def _get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert dist.is_initialized(), 'dist is not initialized'
    global mpu
    if mpu is not None:
        return mpu.get_data_parallel_group()
    # Return the clone of dist world group
    return _clone_world_group()


def _get_broadcast_src_rank():
    return dist.get_global_rank(_get_sequence_data_parallel_group(), 0)


def _get_expert_broadcast_src_rank(group_name):
    return dist.get_global_rank(_get_expert_data_parallel_group(group_name), 0)


def _get_expert_parallel_world_size(group_name):
    """Return world size for the expert parallel group."""
    return dist.get_world_size(group=_get_expert_parallel_group(group_name))


def _get_expert_data_parallel_world_size(group_name):
    """Return world size for the expert data parallel group."""
    return dist.get_world_size(group=_get_expert_data_parallel_group(group_name))


def _get_expert_parallel_rank(group_name):
    """Return my rank for the expert parallel group."""
    return dist.get_rank(group=_get_expert_parallel_group(group_name))


def _get_expert_parallel_src_rank(group_name):
    """Calculate the global rank corresponding to a local rank zero
    in the expert parallel group."""
    global_rank = dist.get_rank()
    local_world_size = _get_expert_parallel_world_size(group_name)
    return (global_rank // local_world_size) * local_world_size
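
# Illustrative note (comments only): this arithmetic assumes the expert parallel
# group covers a contiguous block of global ranks; e.g. with an expert parallel
# world size of 4, global rank 5 maps to source rank (5 // 4) * 4 = 4.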


def _get_expert_data_parallel_rank(group_name):
    """Return my rank for the expert data parallel group."""
    return dist.get_rank(group=_get_expert_data_parallel_group(group_name))


def _get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    global mpu
    if mpu is not None:
        return mpu.get_data_parallel_world_size()
    return dist.get_world_size(group=_get_data_parallel_group())


def _get_model_parallel_world_size():
    """Return world size for the model parallel group."""
    global mpu
    if mpu is not None:
        return mpu.get_model_parallel_world_size()
    return 1


def _get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return dist.get_rank(group=_get_data_parallel_group())


def _get_sequence_parallel_world_size():
    """Return world size for the sequence parallel group."""
    global mpu
    if mpu is not None and hasattr(mpu, 'get_sequence_parallel_world_size'):
        return mpu.get_sequence_parallel_world_size()
    return 1


def _get_sequence_parallel_rank():
    """Return my rank for the sequence parallel group."""
    global mpu
    if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'):
        return mpu.get_sequence_parallel_rank()
    return 0


def _get_sequence_parallel_group():
    global mpu
    if mpu is not None and hasattr(mpu, 'get_sequence_parallel_group'):
        return mpu.get_sequence_parallel_group()
    return None


def _get_sequence_data_parallel_world_size():
    """Return world size for the sequence data parallel group."""
    global mpu
    if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_world_size'):
        return mpu.get_sequence_data_parallel_world_size()
    return _get_data_parallel_world_size()


def _get_sequence_data_parallel_rank():
    """Return my rank for the sequence data parallel group."""
    global mpu
    if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_rank'):
        return mpu.get_sequence_data_parallel_rank()
    return _get_data_parallel_rank()


def _get_sequence_data_parallel_group():
    global mpu
    # When sequence parallelism is enabled, the process group for zero sharding and
    # gradient allreduce must be across both dimensions of data and sequence parallelism.
    if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_group'):
        return mpu.get_sequence_data_parallel_group()
    return _get_data_parallel_group()


def _get_expert_model_parallel_world_size():
    global expert_tensor_parallel_world_size
    return expert_tensor_parallel_world_size


def _create_zero_param_parallel_group(group_size):
    """
    Create parameter partitioning group within ZeRO data parallel groups.

    Example - ZP + D parallel
        world_size = 16
        zero_hpz_partition_size = 2 # number of ranks with replicated params (dual partitioning)
        zero_param_intra_parallel_group = [0,1], [2,3], [4,5], ..., [14,15] - segmented (subgroup) with rep partition
        data_parallel_group = [0,1,...,15] - all reduce is on ZeRO model
    """
    assert dist.is_initialized()
    global _ZERO_PARAM_INTRA_PARALLEL_GROUP
    # Only create group if it does not already exist
    assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is None, \
        'ZeRO parameter intra parallel group is already initialized'

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    zero_param_parallel_size_ = min(group_size, world_size)
    _ensure_divisibility(world_size, zero_param_parallel_size_)

    # Build the ZeRO param intra parallel groups.
    for i in range(world_size // zero_param_parallel_size_):
        ranks = range(i * zero_param_parallel_size_, (i + 1) * zero_param_parallel_size_)
        group = dist.new_group(ranks)
        if i == (rank // zero_param_parallel_size_):
            _ZERO_PARAM_INTRA_PARALLEL_GROUP = group
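
# Illustrative walk-through (comments only, not executed): with world_size = 16
# and group_size = 2 (the zero_hpz_partition_size in the docstring), the loop
# above creates the contiguous pairs [0,1], [2,3], ..., [14,15]; e.g. rank 5
# keeps the group over [4, 5] as its _ZERO_PARAM_INTRA_PARALLEL_GROUP.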


def _get_zero_param_intra_parallel_group():
    """Get the ZeRO parameter partitioning intra parallel group the caller rank belongs to."""
    #assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is not None, \
    #    'ZeRO parameter partitioning group is not initialized'
    #TODO: Add warning
    return _ZERO_PARAM_INTRA_PARALLEL_GROUP


def _zero_param_parallel_is_initialized():
    """Check if ZeRO data parallel with parameter partitioning groups are initialized."""
    ###TODO: assert that MPU is not set
    if _ZERO_PARAM_INTRA_PARALLEL_GROUP is None and _DATA_PARALLEL_GROUP is None:
        return False


def _get_zero_param_intra_parallel_rank_in_mygroup():
    """Return my rank for the ZeRO parameter intra parallel group."""
    return dist.get_rank(group=_get_zero_param_intra_parallel_group())


def _get_zero_param_intra_parallel_group_world_size():
    """Return world size for the ZeRO parameter parallel group."""
    return dist.get_world_size(group=_get_zero_param_intra_parallel_group())


def _get_zero_param_intra_parallel_group_ranks():
    """Return all ranks for the ZeRO parameter intra parallel group."""
    return dist.get_all_ranks_from_group(group=_get_zero_param_intra_parallel_group())