'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Support expert, data, and model (only megatron-style) parallelism in DeepSpeed.

The possible scenarios are:

Scenario 1: There is no expert parallelism or model parallelism (D)
    model = my_model(args)
    engine = deepspeed.initialize(model) ---> initialize groups without mpu

Scenario 2: There is expert parallelism but no model parallelism (E+D)
    deepspeed.utils.groups.initialize(ep_size) --> groups will be initialized here
    model = my_model(args)
    engine = deepspeed.initialize(model) --> don't initialize groups

Scenario 3: There is model parallelism but no expert parallelism (M)
    mpu.init()
    model = my_model(args)
    engine = deepspeed.initialize(model, mpu=mpu) --> initialize groups with mpu but expert_parallel_size = dp_world_size

Scenario 4: There is model, data, and expert parallelism (E+D+M)
    mpu.init()
    deepspeed.utils.groups.initialize(ep_size, mpu) ---> initialize groups with mpu
    model = my_model(args)
    # Valid, but an assert inside DeepSpeed makes sure the mpu passed here is the same one used to initialize the groups
    engine = deepspeed.initialize(model, mpu=mpu)
    # Also valid
    engine = deepspeed.initialize(model)
"""
import torch

from deepspeed.utils import logger, log_dist

# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Expert parallel group that the current rank belongs to.
_EXPERT_PARALLEL_GROUP = None
# Expert data parallel group that the current rank belongs to.
_EXPERT_DATA_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator)


def initialize(ep_size=1, mpu=None):
    """
    Process group initialization supporting expert (E), data (D), and model (M) parallelism. DeepSpeed considers
    the following scenarios w.r.t. process group creation.

    * S1: There is no expert parallelism or model parallelism, only data (D)::

        model = my_model(args)
        engine = deepspeed.initialize(model) # initialize groups without mpu

    * S2: There is expert parallelism but no model parallelism (E+D)::

        deepspeed.utils.groups.initialize(ep_size) # groups will be initialized here
        model = my_model(args)
        engine = deepspeed.initialize(model)

    * S3: There is model parallelism but no expert parallelism (M)::

        mpu.init() # client initializes its model parallel unit
        model = my_model(args)
        engine = deepspeed.initialize(model, mpu=mpu) # init w. mpu but ep_size = dp_world_size

    * S4: There is model, data, and expert parallelism (E+D+M)::

        mpu.init() # client initializes its model parallel unit
        deepspeed.utils.groups.initialize(ep_size, mpu) # initialize expert groups wrt mpu
        model = my_model(args)
        engine = deepspeed.initialize(model, mpu=mpu) # passing mpu is optional in this case

    Arguments:
        ep_size (int, optional): default=1, expert parallel size
        mpu (module, optional): default=None, model parallel unit (e.g., from Megatron)
            that describes model/data parallel ranks.
    """
    if mpu is not None:
        log_dist(message="initializing deepspeed groups using mpu", ranks=[0])
        initialize_model_and_expert_parallel(ep_size, mpu)
    else:
        log_dist(message="initializing deepspeed groups", ranks=[0])
        initialize_model_parallel(1)
        initialize_expert_parallel(ep_size)
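

# A minimal usage sketch for scenario S2 (E+D), assuming an 8-GPU launch; the
# `my_model`, `args`, and `args.deepspeed_config` names below are hypothetical
# placeholders, and the snippet is kept as a comment so that importing this
# module has no side effects:
#
#     import deepspeed
#     from deepspeed.utils import groups
#
#     groups.initialize(ep_size=2)   # expert/data groups built here, not inside deepspeed.initialize
#     model = my_model(args)
#     engine, _, _, _ = deepspeed.initialize(args=args,
#                                            model=model,
#                                            model_parameters=model.parameters())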


def initialize_model_parallel(model_parallel_size_):
    """
    Initialize model and data parallel groups.

    Arguments:
        model_parallel_size_: number of GPUs used to parallelize the model.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    log_dist(
        'initializing deepspeed model parallel group with size {}'.format(
            model_parallel_size_),
        [0])
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    model_parallel_size = min(model_parallel_size_, world_size)
    ensure_divisibility(world_size, model_parallel_size)
    rank = torch.distributed.get_rank()

    # Build the data parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group is already initialized'
    for i in range(model_parallel_size):
        ranks = range(i, world_size, model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if i == (rank % model_parallel_size):
            _DATA_PARALLEL_GROUP = group

    # Build the model parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
        'model parallel group is already initialized'
    for i in range(world_size // model_parallel_size):
        ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if i == (rank // model_parallel_size):
            _MODEL_PARALLEL_GROUP = group
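

# Illustrative sketch only (this helper is not part of the original module): it
# repeats the rank arithmetic used above, without touching torch.distributed,
# for the 8-GPU / model-parallel-size-2 example in the docstring.
def _example_model_and_data_groups(world_size=8, model_parallel_size=2):
    # data groups stride by model_parallel_size -> [[0, 2, 4, 6], [1, 3, 5, 7]]
    data_groups = [list(range(i, world_size, model_parallel_size))
                   for i in range(model_parallel_size)]
    # model groups are contiguous blocks -> [[0, 1], [2, 3], [4, 5], [6, 7]]
    model_groups = [list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
                    for i in range(world_size // model_parallel_size)]
    return data_groups, model_groups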


def initialize_expert_parallel(expert_parallel_size_):
    """
    Initialize expert plus data parallel groups.

    Example - E + D parallel
        world_size = 16
        expert_parallel_size = 2 # number of experts in same group
        expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
        expert_parallel_group = [0,1], [2,3], [4,5], [6,7], ..., [14,15] - no all reduce, but all to all
        data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE params
    """
    assert torch.distributed.is_initialized()
    log_dist(
        'initializing deepspeed expert parallel group with size {}'.format(
            expert_parallel_size_),
        [0])
    world_size = get_data_parallel_world_size()
    rank = get_data_parallel_rank()
    expert_parallel_size_ = min(expert_parallel_size_, world_size)
    ensure_divisibility(world_size, expert_parallel_size_)

    # Build the expert data parallel groups.
    global _EXPERT_DATA_PARALLEL_GROUP
    assert _EXPERT_DATA_PARALLEL_GROUP is None, \
        'expert data parallel group is already initialized'
    for i in range(expert_parallel_size_):
        ranks = range(i, world_size, expert_parallel_size_)
        group = torch.distributed.new_group(ranks)
        # TODO: remove
        log_dist(
            f'creating expert data parallel process group with ranks: {list(ranks)}',
            [0])
        if i == (rank % expert_parallel_size_):
            _EXPERT_DATA_PARALLEL_GROUP = group

    # Build the expert parallel groups.
    global _EXPERT_PARALLEL_GROUP
    assert _EXPERT_PARALLEL_GROUP is None, \
        'expert parallel group is already initialized'
    for i in range(world_size // expert_parallel_size_):
        ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
        group = torch.distributed.new_group(ranks)
        # TODO: remove
        log_dist(f'creating expert parallel process group with ranks: {list(ranks)}',
                 [0])
        if i == (rank // expert_parallel_size_):
            _EXPERT_PARALLEL_GROUP = group
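

# Illustrative sketch only (this helper is not part of the original module): the
# same rank arithmetic as above for the world_size=16 / expert_parallel_size=2
# example in the docstring.
def _example_expert_groups(world_size=16, expert_parallel_size=2):
    # expert data groups stride by expert_parallel_size -> [[0, 2, ..., 14], [1, 3, ..., 15]]
    expert_data_groups = [list(range(i, world_size, expert_parallel_size))
                          for i in range(expert_parallel_size)]
    # expert groups are contiguous blocks -> [[0, 1], [2, 3], ..., [14, 15]]
    expert_groups = [list(range(i * expert_parallel_size, (i + 1) * expert_parallel_size))
                     for i in range(world_size // expert_parallel_size)]
    return expert_data_groups, expert_groups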


def initialize_model_and_expert_parallel(expert_parallel_size_, mpu):
    """
    Initialize expert groups based on the MPU groups.

    Example - E + M + D parallel
        world_size = 16
        model_degree = 2
        expert_degree = 4 # number of experts in same group
        mp_group = [0,1], [2,3], [4,5], ..., [14,15]
        data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15]
        expert_parallel_group = [0,2,4,6], [8,10,12,14], [1,3,5,7], [9,11,13,15]
        expert_data_parallel_group = [0,8], [2,10], [4,12], [6,14], [1,9], [3,11], [5,13], [7,15]
    """
    assert torch.distributed.is_initialized(), "torch distributed is not initialized"
    assert mpu.model_parallel_is_initialized(), "model parallel group is not initialized"
    model_parallel_size_ = mpu.get_model_parallel_world_size()

    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    dp_world_size = mpu.get_data_parallel_world_size()
    dp_rank = mpu.get_data_parallel_rank()

    log_dist(
        f"Initializing deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, and data parallel size {dp_world_size}",
        [0])

    global _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
    global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP

    # Get world size and rank. Ensure some consistencies.
    _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
    _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()

    expert_parallel_size_ = min(expert_parallel_size_, dp_world_size)
    ensure_divisibility(world_size, expert_parallel_size_)

    # Build the expert data parallel groups.
    assert _EXPERT_DATA_PARALLEL_GROUP is None, \
        'expert data parallel group is already initialized'

    # Build the expert parallel groups.
    assert _EXPERT_PARALLEL_GROUP is None, \
        'expert parallel group is already initialized'

    for j in range(model_parallel_size_):
        for i in range(expert_parallel_size_):
            ranks = range(i * model_parallel_size_ + j,
                          world_size,
                          expert_parallel_size_ * model_parallel_size_)
            group = torch.distributed.new_group(ranks)
            # TODO: remove
            log_dist(
                f'creating expert data parallel process group with ranks: {list(ranks)}',
                [0])
            if rank in list(ranks):
                _EXPERT_DATA_PARALLEL_GROUP = group

        for i in range(dp_world_size // expert_parallel_size_):
            ranks = range(i * expert_parallel_size_ * model_parallel_size_ + j,
                          (i + 1) * expert_parallel_size_ * model_parallel_size_,
                          model_parallel_size_)
            group = torch.distributed.new_group(ranks)
            # TODO: remove
            log_dist(f'creating expert parallel process group with ranks: {list(ranks)}',
                     [0])
            if rank in list(ranks):
                _EXPERT_PARALLEL_GROUP = group
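

# Illustrative sketch only (this helper is not part of the original module): the
# nested-loop rank arithmetic used above, for the docstring example with
# world_size=16, model parallel degree 2, and expert parallel degree 4.
def _example_model_expert_groups(world_size=16,
                                 model_parallel_size=2,
                                 expert_parallel_size=4):
    dp_world_size = world_size // model_parallel_size
    expert_data_groups, expert_groups = [], []
    for j in range(model_parallel_size):
        # expert data groups stride over both degrees:
        # [[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], [7, 15]]
        for i in range(expert_parallel_size):
            expert_data_groups.append(
                list(range(i * model_parallel_size + j,
                           world_size,
                           expert_parallel_size * model_parallel_size)))
        # expert groups stride by the model parallel degree:
        # [[0, 2, 4, 6], [8, 10, 12, 14], [1, 3, 5, 7], [9, 11, 13, 15]]
        for i in range(dp_world_size // expert_parallel_size):
            expert_groups.append(
                list(range(i * expert_parallel_size * model_parallel_size + j,
                           (i + 1) * expert_parallel_size * model_parallel_size,
                           model_parallel_size)))
    return expert_data_groups, expert_groups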


def is_initialized():
    """Check if deepspeed groups have been initialized."""
    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or \
            _EXPERT_PARALLEL_GROUP is None or _EXPERT_DATA_PARALLEL_GROUP is None:
        return False
    return True


def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
        return False
    return True


def expert_parallel_is_initialized():
    """Check if expert and expert data parallel groups are initialized."""
    if _EXPERT_PARALLEL_GROUP is None or _EXPERT_DATA_PARALLEL_GROUP is None:
        return False
    return True


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_expert_parallel_group():
    """Get the expert parallel group the caller rank belongs to."""
    assert _EXPERT_PARALLEL_GROUP is not None, \
        'expert parallel group is not initialized'
    return _EXPERT_PARALLEL_GROUP


def get_expert_data_parallel_group():
    """Get the expert data parallel group the caller rank belongs to."""
    assert _EXPERT_DATA_PARALLEL_GROUP is not None, \
        'expert data parallel group is not initialized'
    return _EXPERT_DATA_PARALLEL_GROUP


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_model_parallel_world_size():
    """Return world size for the model parallel group."""
    return torch.distributed.get_world_size(group=get_model_parallel_group())


def get_expert_parallel_world_size():
    """Return world size for the expert parallel group."""
    return torch.distributed.get_world_size(group=get_expert_parallel_group())


def get_expert_data_parallel_world_size():
    """Return world size for the expert data parallel group."""
    return torch.distributed.get_world_size(group=get_expert_data_parallel_group())


def get_model_parallel_rank():
    """Return my rank for the model parallel group."""
    return torch.distributed.get_rank(group=get_model_parallel_group())


def get_expert_parallel_rank():
    """Return my rank for the expert parallel group."""
    return torch.distributed.get_rank(group=get_expert_parallel_group())


def get_model_parallel_src_rank():
    """Calculate the global rank corresponding to local rank zero
    in the model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size


def get_expert_parallel_src_rank():
    """Calculate the global rank corresponding to local rank zero
    in the expert parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_expert_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
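

# Note (illustrative only): both *_src_rank helpers above compute
# (global_rank // group_world_size) * group_world_size, i.e. they round the
# caller's global rank down to the nearest multiple of the group's world size.
# For example, with a group world size of 2, global ranks 4 and 5 both map to
# source rank 4.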


def get_expert_data_parallel_rank():
    """Return my rank for the expert data parallel group."""
    return torch.distributed.get_rank(group=get_expert_data_parallel_group())


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())


def destroy_model_parallel():
    """Set the model parallel and data parallel groups to None."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None