groups.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. # The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file
  5. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. """
  19. Support different forms of parallelism in DeepSpeed using multiple process groups.
  20. Given that there are multiple scenarios and use-cases, this file is going to be updated
  21. frequently. For now, the group creation needed for the training scenario is being implemented.
  22. For inference and other new scenarios, the code will be either reused or added to this file.
  23. """
  24. from deepspeed import comm as dist
  25. from deepspeed.utils import log_dist
  26. from deepspeed.utils.exceptions import DeprecatedException
  27. from deepspeed.accelerator import get_accelerator
# Expert parallel group that the current rank belongs to.
# Dict keyed by group name of the form "ep_size_<N>"; entries are added by the _create_* functions.
_EXPERT_PARALLEL_GROUP = {}
# Expert data parallel group that the current rank belongs to.
# Keyed the same way as _EXPERT_PARALLEL_GROUP ("ep_size_<N>").
_EXPERT_DATA_PARALLEL_GROUP = {}
# Clone of the dist world group, created lazily by _clone_world_group()
# (dist.get_global_rank() does not work on dist.group.WORLD itself).
_WORLD_GROUP = None
# ZeRO parameter partitioning group that the current rank belongs to.
_ZERO_PARAM_INTRA_PARALLEL_GROUP = None
# global object to maintain mpu object if passed by a Megatron client
mpu = None
# global object that stores tensor parallel world size for experts
expert_tensor_parallel_world_size = 1
# All-to-all quantized gradient communication groups, keyed "local_<i>" / "global_<i>".
_ALL_TO_ALL_GROUP = {}
# Module-level data parallel group.
# NOTE(review): no function in this part of the file assigns this global (the _create_*
# helpers only bind local names of the same spelling) — confirm whether it is set elsewhere.
_DATA_PARALLEL_GROUP = None
  43. # Deprecated groups initialize function.
  44. def initialize(ep_size=1, mpu=None):
  45. """ Deprecated function. Retained to inform the users."""
  46. raise DeprecatedException(
  47. "Please do not use the groups.initialize() API as it is deprecated. Instead, pass the desired ep_size to deepspeed.moe.layer.MoE(..,ep_size,..)"
  48. )
  49. def _ensure_divisibility(numerator, denominator):
  50. """Ensure that numerator is divisible by the denominator."""
  51. assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
  52. # Not currently used. Helper function to create a model (tensor) parallel group.
  53. def _create_model_parallel(model_parallel_size_):
  54. """
  55. Initialize model data parallel groups.
  56. Arguments:
  57. model_parallel_size: number of GPUs used to parallelize model.
  58. Returns:
  59. Tuple of data parallel group and model parallel group
  60. Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
  61. use 2 GPUs to parallelize the model. The present function will
  62. create 4 model parallel groups and 2 data parallel groups as:
  63. 4 model parallel groups:
  64. [g0, g1], [g2, g3], [g4, g5], [g6, g7]
  65. 2 data parallel groups:
  66. [g0, g2, g4, g6], [g1, g3, g5, g7]
  67. Note that for efficiency, the caller should make sure adjacent ranks
  68. are on the same DGX box. For example if we are using 2 DGX-1 boxes
  69. with a total of 16 GPUs, rank 0 to 7 belong to the first box and
  70. ranks 8 to 15 belong to the second box.
  71. """
  72. log_dist(f'Creating model parallel group with size {model_parallel_size_}', ranks=[0])
  73. # Get world size and rank. Ensure some consistencies.
  74. assert dist.is_initialized()
  75. world_size = dist.get_world_size()
  76. model_parallel_size = min(model_parallel_size_, world_size)
  77. _ensure_divisibility(world_size, model_parallel_size)
  78. rank = dist.get_rank()
  79. _DATA_PARALLEL_GROUP = None
  80. _MODEL_PARALLEL_GROUP = None
  81. # Build the data parallel groups.
  82. for i in range(model_parallel_size):
  83. ranks = range(i, world_size, model_parallel_size)
  84. group = dist.new_group(ranks)
  85. if i == (rank % model_parallel_size):
  86. _DATA_PARALLEL_GROUP = group
  87. # Build the model parallel groups.
  88. for i in range(world_size // model_parallel_size):
  89. ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
  90. group = dist.new_group(ranks)
  91. if i == (rank // model_parallel_size):
  92. _MODEL_PARALLEL_GROUP = group
  93. return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
  94. def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expert_parallel_=False):
  95. """
  96. Create expert and data parallel groups.
  97. Note: Caller of this function is responsible to check if the groups already exist.
  98. Example - E + D parallel
  99. world_size = 16
  100. expert_parallel_size = 2 # number of experts in same group
  101. expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
  102. expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all
  103. data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE
  104. use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology
  105. """
  106. assert dist.is_initialized()
  107. log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0])
  108. world_size = dist.get_world_size()
  109. rank = dist.get_rank()
  110. _ensure_divisibility(world_size, expert_parallel_size_)
  111. group_name = f"ep_size_{expert_parallel_size_}"
  112. # Build the expert data parallel groups.
  113. global _EXPERT_DATA_PARALLEL_GROUP
  114. ep_stride = world_size // expert_parallel_size_
  115. # Only create group if it does not already exist
  116. if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
  117. for i in range(expert_parallel_size_):
  118. if use_data_before_expert_parallel_:
  119. ranks = range(i * ep_stride, (i + 1) * ep_stride)
  120. else:
  121. ranks = range(i, world_size, expert_parallel_size_)
  122. group = dist.new_group(ranks)
  123. log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', [0])
  124. if use_data_before_expert_parallel_:
  125. if i == (rank // ep_stride):
  126. _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
  127. else:
  128. if i == (rank % expert_parallel_size_):
  129. _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
  130. # Build the expert parallel groups.
  131. global _EXPERT_PARALLEL_GROUP
  132. # Only create group if it does not already exist
  133. if group_name not in _EXPERT_PARALLEL_GROUP:
  134. if use_data_before_expert_parallel_:
  135. for i in range(ep_stride):
  136. ranks = range(i, world_size, ep_stride)
  137. group = dist.new_group(ranks)
  138. log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
  139. if i == (rank % ep_stride):
  140. _EXPERT_PARALLEL_GROUP[group_name] = group
  141. else:
  142. for i in range(world_size // expert_parallel_size_):
  143. ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
  144. group = dist.new_group(ranks)
  145. log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
  146. if i == (rank // expert_parallel_size_):
  147. _EXPERT_PARALLEL_GROUP[group_name] = group
  148. def _get_expert_parallel_ranks(world_size,
  149. model_parallel_size_,
  150. expert_parallel_size_,
  151. use_data_before_expert_parallel_=False):
  152. """Generate expert parallel and expert data parallel group ranks list.
  153. Example - E + M + D parallel
  154. world_size = 16
  155. model_degree = 2
  156. expert_degree = 4 # number of experts in same group
  157. mp_group = [0, 1], [2,3], [4,5] ...
  158. data_parallel_group =[0,2,4,6,8,10, 12,14], [1,3,5,7,9,11,13,15]
  159. expert_parallel_group = [0,2,4,6], [8,10,12,14] [1,3,5,7], [9,11,13,15]
  160. expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
  161. Args:
  162. world_size (int): Distributed world size.
  163. model_parallel_size_ (int): Model parallel group size.
  164. expert_parallel_size_ (int): Expert parallel group size.
  165. use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology
  166. Returns:
  167. Expert parallel group ranks and Expert data parallel group ranks list.
  168. """
  169. _ensure_divisibility(world_size, model_parallel_size_)
  170. dp_world_size = world_size // model_parallel_size_
  171. _ensure_divisibility(dp_world_size, expert_parallel_size_)
  172. # Generate data parallel groups
  173. data_parallel_groups = []
  174. dp_group_size = model_parallel_size_
  175. if use_data_before_expert_parallel_:
  176. dp_stride = world_size // expert_parallel_size_ // model_parallel_size_
  177. for i in range(dp_group_size):
  178. data_parallel_groups.append(list())
  179. for ds in range(dp_stride):
  180. # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30]
  181. # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31]
  182. data_parallel_groups[-1].extend(
  183. list(range(i + ds * model_parallel_size_, world_size, dp_stride * model_parallel_size_)))
  184. else:
  185. for i in range(dp_group_size):
  186. data_parallel_groups.append(list(range(i, world_size, dp_group_size)))
  187. expert_parallel_groups = []
  188. expert_data_parallel_groups = []
  189. for dp_ranks in data_parallel_groups:
  190. # partition of expert parallel groups, e.g. [0,2,4,6], [8,10,12,14]
  191. part_ep_groups = []
  192. for i in range(0, dp_world_size, expert_parallel_size_):
  193. part_ep_groups.append(dp_ranks[i:i + expert_parallel_size_])
  194. expert_parallel_groups.extend(part_ep_groups)
  195. # zip part_ep_groups get expert data parallel ranks, e.g [0,8],[2,10],[4,12],[6,14]
  196. for expert_dp_ranks in zip(*part_ep_groups):
  197. expert_data_parallel_groups.append(list(expert_dp_ranks))
  198. return expert_parallel_groups, expert_data_parallel_groups
  199. def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_before_expert_parallel_=False):
  200. """
  201. Create expert and data parallel groups based on MPU (model parallel) group.
  202. Note: Caller of this function is responsible to check if the groups already exist.
  203. Example - E + M + D parallel
  204. world_size = 16
  205. model_degree = 2
  206. expert_degree = 4 # number of experts in same group
  207. mp_group = [0, 1], [2,3], [4,5] ...
  208. data_parallel_group =[0,2,4,6,8,10, 12,14], [1,3,5,7,9,11,13,15]
  209. expert_parallel_group = [0,2,4,6], [8,10,12,14] [1,3,5,7], [9,11,13,15]
  210. expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
  211. """
  212. assert dist.is_initialized(), "dist is not initialized"
  213. model_parallel_size_ = mpu.get_model_parallel_world_size()
  214. global expert_tensor_parallel_world_size
  215. expert_tensor_parallel_world_size = model_parallel_size_
  216. world_size = dist.get_world_size()
  217. rank = dist.get_rank()
  218. dp_world_size = mpu.get_data_parallel_world_size()
  219. dp_rank = mpu.get_data_parallel_rank()
  220. _ensure_divisibility(world_size, model_parallel_size_)
  221. _ensure_divisibility(dp_world_size, expert_parallel_size_)
  222. log_dist(
  223. f"Creating deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}",
  224. [0])
  225. global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP
  226. # Get world size and rank. Ensure some consistencies.
  227. _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
  228. _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()
  229. group_name = f"ep_size_{expert_parallel_size_}"
  230. # Only create groups if they don't already exist
  231. # Need to check conditions outside the group creation loop because of the way torch.dist group creation works
  232. if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP:
  233. expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
  234. world_size, model_parallel_size_, expert_parallel_size_, use_data_before_expert_parallel_)
  235. for ranks in expert_parallel_groups:
  236. group = dist.new_group(ranks)
  237. if rank in list(ranks):
  238. _EXPERT_PARALLEL_GROUP[group_name] = group
  239. for ranks in expert_data_parallel_groups:
  240. group = dist.new_group(ranks)
  241. if rank in list(ranks):
  242. _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
  243. def _get_max_expert_size():
  244. """Get the maximum ep_size from all the created groups."""
  245. assert _EXPERT_PARALLEL_GROUP is not None, "Warning! Process group not initialized"
  246. keylist = []
  247. for key in _EXPERT_PARALLEL_GROUP.keys():
  248. # index 2 is ep_size in the group name: ep_size_<ep_size>
  249. index = 2
  250. keylist.append(int(key.split('_')[index]))
  251. return max(keylist) if len(keylist) > 0 else None
  252. def _get_max_expert_size_name():
  253. """Get the name of the group with max. ep_size"""
  254. return f'ep_size_{_get_max_expert_size()}'
  255. def _get_max_expert_parallel_group():
  256. """Get the max expert parallel size."""
  257. return _get_expert_parallel_group(_get_max_expert_size_name())
  258. def _get_expert_parallel_group(group_name):
  259. """Get the expert parallel group the caller rank belongs to."""
  260. assert group_name in _EXPERT_PARALLEL_GROUP, \
  261. 'expert parallel group is not initialized'
  262. return _EXPERT_PARALLEL_GROUP[group_name]
  263. def _get_expert_parallel_group_dict():
  264. """Get the expert parallel group dict."""
  265. return _EXPERT_PARALLEL_GROUP
  266. def _get_expert_data_parallel_group(group_name):
  267. """Get the expert data parallel group the caller rank belongs to."""
  268. assert group_name in _EXPERT_DATA_PARALLEL_GROUP, \
  269. 'expert data parallel group is not initialized'
  270. return _EXPERT_DATA_PARALLEL_GROUP[group_name]
  271. def _get_expert_data_parallel_group_dict():
  272. """Get the expert data parallel group dict."""
  273. return _EXPERT_DATA_PARALLEL_GROUP
  274. def _clone_world_group():
  275. """Create a clone of the world group
  276. Note: We need to clone the dist world group because we
  277. use dist.get_global_rank() utility function in DeepSpeed at many places.
  278. As that function does not work on dist.group.WORLD, we
  279. need to keep a clone of it.
  280. """
  281. assert dist.is_initialized(), "dist is not initialized"
  282. global _WORLD_GROUP
  283. if _WORLD_GROUP is None:
  284. # If not cloned already, clone the world group
  285. _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size()))
  286. return _WORLD_GROUP
  287. def _get_local_all_to_all_group():
  288. assert dist.is_initialized(), 'dist is not initialized'
  289. global _ALL_TO_ALL_GROUP
  290. device_per_node = get_accelerator().device_count()
  291. num_local = dist.get_world_size() // device_per_node
  292. if num_local == 0 and dist.get_world_size() > 0:
  293. assert dist.get_world_size() >= 1, 'num_gpus must >=1, cannot initialize All-To-All'
  294. cur_rank = []
  295. for i in range(dist.get_world_size()):
  296. cur_rank.append(i)
  297. _ALL_TO_ALL_GROUP['local_0'] = dist.new_group(ranks=cur_rank)
  298. elif num_local == 1:
  299. assert dist.get_world_size(
  300. ) == device_per_node, 'num_gpus not equal to device per node, cannot initialize All-To-All'
  301. _ALL_TO_ALL_GROUP['local_0'] = dist.new_group(ranks=[i for i in range(device_per_node)])
  302. else:
  303. assert dist.get_world_size() > device_per_node, 'num_nodes<2 cannot initialize All-To-All'
  304. for i in range(num_local):
  305. local_rank = [j + device_per_node * i for j in range(device_per_node)]
  306. _ALL_TO_ALL_GROUP[f"local_{i}"] = dist.new_group(ranks=local_rank)
  307. for i in range(device_per_node):
  308. cur_rank = []
  309. for j in range(num_local):
  310. cur_rank.append(i + j * device_per_node)
  311. _ALL_TO_ALL_GROUP[f"global_{i}"] = dist.new_group(ranks=cur_rank)
  312. return _ALL_TO_ALL_GROUP
  313. def _get_data_parallel_group():
  314. """Get the data parallel group the caller rank belongs to."""
  315. assert dist.is_initialized(), 'dist is not initialized'
  316. global mpu
  317. if mpu is not None:
  318. return mpu.get_data_parallel_group()
  319. # Return the clone of dist world group
  320. return _clone_world_group()
  321. def _get_broadcast_src_rank():
  322. return dist.get_global_rank(_get_sequence_data_parallel_group(), 0)
  323. def _get_expert_broadcast_src_rank(group_name):
  324. return dist.get_global_rank(_get_expert_data_parallel_group(group_name), 0)
  325. def _get_expert_parallel_world_size(group_name):
  326. """Return world size for the expert parallel group."""
  327. return dist.get_world_size(group=_get_expert_parallel_group(group_name))
  328. def _get_expert_data_parallel_world_size(group_name):
  329. """Return world size for the expert data parallel group."""
  330. return dist.get_world_size(group=_get_expert_data_parallel_group(group_name))
  331. def _get_expert_parallel_rank(group_name):
  332. """Return my rank for the expert parallel group."""
  333. return dist.get_rank(group=_get_expert_parallel_group(group_name))
  334. def _get_expert_parallel_src_rank(group_name):
  335. """Calculate the global rank corresponding to a local rank zero
  336. in the expert parallel group."""
  337. global_rank = dist.get_rank()
  338. local_world_size = _get_expert_parallel_world_size(group_name)
  339. return (global_rank // local_world_size) * local_world_size
  340. def _get_expert_data_parallel_rank(group_name):
  341. """Return my rank for the expert data parallel group."""
  342. return dist.get_rank(group=_get_expert_data_parallel_group(group_name))
  343. def _get_data_parallel_world_size():
  344. """Return world size for the data parallel group."""
  345. global mpu
  346. if mpu is not None:
  347. return mpu.get_data_parallel_world_size()
  348. return dist.get_world_size(group=_get_data_parallel_group())
  349. def _get_model_parallel_world_size():
  350. """Return world size for the model parallel group."""
  351. global mpu
  352. if mpu is not None:
  353. return mpu.get_model_parallel_world_size()
  354. return 1
  355. def _get_data_parallel_rank():
  356. """Return my rank for the data parallel group."""
  357. return dist.get_rank(group=_get_data_parallel_group())
  358. def _get_sequence_parallel_world_size():
  359. """Return world size for the model parallel group."""
  360. global mpu
  361. if mpu is not None and hasattr(mpu, 'get_sequence_parallel_world_size'):
  362. return mpu.get_sequence_parallel_world_size()
  363. return 1
  364. def _get_sequence_parallel_rank():
  365. """Return my rank for the data parallel group."""
  366. global mpu
  367. if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'):
  368. return mpu.get_sequence_parallel_rank()
  369. return 0
  370. def _get_sequence_parallel_group():
  371. global mpu
  372. if mpu is not None and hasattr(mpu, 'get_sequence_parallel_group'):
  373. return mpu.get_sequence_parallel_group()
  374. return None
  375. def _get_sequence_data_parallel_world_size():
  376. """Return world size for the model parallel group."""
  377. global mpu
  378. if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_world_size'):
  379. return mpu.get_sequence_data_parallel_world_size()
  380. return _get_data_parallel_world_size()
  381. def _get_sequence_data_parallel_rank():
  382. """Return my rank for the data parallel group."""
  383. global mpu
  384. if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_rank'):
  385. return mpu.get_sequence_data_parallel_rank()
  386. return _get_data_parallel_rank()
  387. def _get_sequence_data_parallel_group():
  388. global mpu
  389. # When sequence parallelism is enabled, the process group for zero sharding and
  390. # gradient allreduce must be across both dimensions of data and sequence parallelism.
  391. if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_group'):
  392. return mpu.get_sequence_data_parallel_group()
  393. return _get_data_parallel_group()
  394. def _get_expert_model_parallel_world_size():
  395. global expert_tensor_parallel_world_size
  396. return expert_tensor_parallel_world_size
  397. def _create_zero_param_parallel_group(group_size):
  398. """
  399. Create parameter partitioning group within ZeRO data parallel groups.
  400. Example - ZP + D parallel
  401. world_size = 16
  402. zero_hpz_partition_size = 2 # number of ranks with replicated params (dual partitioning)
  403. zero_param_intra_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - segmented (subgroup) with rep partition
  404. data_parallel_group = [0,1,...,15] - all reduce is on ZeRO model
  405. """
  406. assert dist.is_initialized()
  407. global _ZERO_PARAM_INTRA_PARALLEL_GROUP
  408. # Only create group if it does not already exist
  409. assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is None, \
  410. 'ZeRO parameter intra parallel group is already initialized'
  411. world_size = dist.get_world_size()
  412. rank = dist.get_rank()
  413. zero_param_parallel_size_ = min(group_size, world_size)
  414. _ensure_divisibility(world_size, zero_param_parallel_size_)
  415. # Build the ZeRO param intra parallel groups.
  416. for i in range(world_size // zero_param_parallel_size_):
  417. ranks = range(i * zero_param_parallel_size_, (i + 1) * zero_param_parallel_size_)
  418. group = dist.new_group(ranks)
  419. if i == (rank // zero_param_parallel_size_):
  420. _ZERO_PARAM_INTRA_PARALLEL_GROUP = group
  421. def _get_zero_param_intra_parallel_group():
  422. """Get the ZeRO parameter partitioning intra parallel group the caller rank belongs to."""
  423. #assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is not None, \
  424. # 'ZeRO parameter partitioning group is not initialized'
  425. #TODO: Add warning
  426. return _ZERO_PARAM_INTRA_PARALLEL_GROUP
  427. def _zero_param_parallel_is_initialized():
  428. """Check if ZeRO data parallel with parameter partititioning groups are initialized."""
  429. ###TODO: assert that MPU is not set
  430. if _ZERO_PARAM_INTRA_PARALLEL_GROUP is None and _DATA_PARALLEL_GROUP is None:
  431. return False
  432. def _get_zero_param_intra_parallel_rank_in_mygroup():
  433. """Return my rank for the ZeRO parameter inter parallel group."""
  434. return dist.get_rank(group=_get_zero_param_intra_parallel_group())
  435. def _get_zero_param_intra_parallel_group_world_size():
  436. """Return world size for the ZeRO parameter parallel group."""
  437. return dist.get_world_size(group=_get_zero_param_intra_parallel_group())
  438. def _get_zero_param_intra_parallel_group_ranks():
  439. """Return all ranks for the ZeRO parameter intra parallel group."""
  440. return dist.get_all_ranks_from_group(group=_get_zero_param_intra_parallel_group())