topology.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist

from collections import namedtuple
from itertools import product as cartesian_product


class ProcessTopology:
    """ Manages the mapping of n-dimensional Cartesian coordinates to linear
    indices. This mapping is used to map the rank of processes to the grid
    for various forms of parallelism.

    Each axis of the tensor is accessed by its name. The provided ordering
    of the axes defines the layout of the topology. ProcessTopology uses a "row-major"
    layout of the tensor axes, and so axes=['x', 'y'] would map coordinates (x,y) and
    (x,y+1) to adjacent linear indices. If instead axes=['y', 'x'] was used, coordinates
    (x,y) and (x+1,y) would be adjacent.

    Some methods return ProcessCoord namedtuples.
    """

    def __init__(self, axes, dims):
        """Create a mapping of n-dimensional tensor coordinates to linear indices.

        Arguments:
            axes (list): the names of the tensor axes
            dims (list): the dimension (length) of each axis of the topology tensor
        """

        self.axes = axes  # names of each topology axis
        self.dims = dims  # length of each topology axis

        # This is actually a class that lets us hash {'row':3, 'col':2} mappings
        self.ProcessCoord = namedtuple('ProcessCoord', axes)

        self.mapping = {}
        ranges = [range(d) for d in dims]
        # example: 1, (0,0,1)
        for global_rank, coord in enumerate(cartesian_product(*ranges)):
            key = {axis: coord[self.axes.index(axis)] for axis in self.axes}
            key = self.ProcessCoord(**key)
            # for example, {ProcessCoord(row=0, col=1) : 1}
            self.mapping[key] = global_rank

    def get_rank(self, **coord_kwargs):
        """Return the global rank of a process via its coordinates.

        Coordinates are specified as kwargs. For example:

            >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
            >>> X.get_rank(x=0, y=1)
            1
        """
        if len(coord_kwargs) != len(self.axes):
            raise ValueError('get_rank() does not support slices. Use filter_match().')

        key = self.ProcessCoord(**coord_kwargs)
        assert key in self.mapping, f'key {coord_kwargs} invalid'
        return self.mapping[key]

    def get_axis_names(self):
        """Return a list of the axis names in the ordering of the topology. """
        return self.axes

    def get_rank_repr(self, rank, omit_axes=['data', 'pipe'], inner_sep='_', outer_sep='-'):
        """Return a string representation of a rank.

        This method is primarily used for checkpointing model data.

        For example:
            >>> topo = Topo(axes=['a', 'b'], dims=[2, 2])
            >>> topo.get_rank_repr(rank=3)
            'a_01-b_01'
            >>> topo.get_rank_repr(rank=3, omit_axes=['a'])
            'b_01'

        Args:
            rank (int): A rank in the topology.
            omit_axes (list, optional): Axes that should not be in the representation. Defaults to ['data', 'pipe'].
            inner_sep (str, optional): Separator between an axis name and its coordinate. Defaults to '_'.
            outer_sep (str, optional): Separator between the per-axis substrings. Defaults to '-'.

        Returns:
            str: A string representation of the coordinate owned by ``rank``.
        """
        omit_axes = frozenset(omit_axes)
        axes = [a for a in self.get_axis_names() if a not in omit_axes]
        names = []
        for ax in axes:
            ax_rank = getattr(self.get_coord(rank=rank), ax)
            names.append(f'{ax}{inner_sep}{ax_rank:02d}')
        return outer_sep.join(names)

    def get_dim(self, axis):
        """Return the number of processes along the given axis.

        For example:
            >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
            >>> X.get_dim('y')
            3
        """
        if axis not in self.axes:
            return 0
        return self.dims[self.axes.index(axis)]

    def get_coord(self, rank):
        """Return the coordinate owned by a process rank.

        The axes of the returned namedtuple can be directly accessed as members. For
        example:
            >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
            >>> coord = X.get_coord(rank=1)
            >>> coord.x
            0
            >>> coord.y
            1
        """
        for coord, idx in self.mapping.items():
            if idx == rank:
                return coord
        raise ValueError(f'rank {rank} not found in topology.')

    def get_axis_comm_lists(self, axis):
        """ Construct lists suitable for a communicator group along axis ``axis``.

        Example:
            >>> topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
            >>> topo.get_axis_comm_lists('pipe')
            [
                [0, 4], # data=0, model=0
                [1, 5], # data=0, model=1
                [2, 6], # data=1, model=0
                [3, 7], # data=1, model=1
            ]

        Returns:
            A list of lists whose coordinates match in all axes *except* ``axis``.
        """
        # Return an empty list instead of raising an error so that callers can write
        # more generalized code for hybrid parallelisms.
        if axis not in self.axes:
            return []

        # Grab all axes but `axis`
        other_axes = [a for a in self.axes if a != axis]

        lists = []

        # Construct all combinations of coords with other_axes
        ranges = [range(self.get_dim(a)) for a in other_axes]
        for coord in cartesian_product(*ranges):
            other_keys = {a: coord[other_axes.index(a)] for a in other_axes}
            # now go over all ranks in `axis`.
            sub_list = []
            for axis_key in range(self.get_dim(axis)):
                key = self.ProcessCoord(**other_keys, **{axis: axis_key})
                sub_list.append(self.mapping[key])
            lists.append(sub_list)

        return lists

    def filter_match(self, **filter_kwargs):
        """Return the list of ranks whose coordinates match the provided criteria.

        Example:
            >>> X = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
            >>> X.filter_match(pipe=0, data=1)
            [2, 3]
            >>> [X.get_coord(rank) for rank in X.filter_match(pipe=0, data=1)]
            [ProcessCoord(pipe=0, data=1, model=0), ProcessCoord(pipe=0, data=1, model=1)]

        Arguments:
            **filter_kwargs (dict): criteria used to select coordinates.

        Returns:
            The list of ranks whose coordinates match filter_kwargs.
        """

        def _filter_helper(x):
            for key, val in filter_kwargs.items():
                if getattr(x, key) != val:
                    return False
            return True

        coords = filter(_filter_helper, self.mapping.keys())
        return [self.mapping[coord] for coord in coords]

    def get_axis_list(self, axis, idx):
        """Returns the list of global ranks whose coordinate in an axis is idx.

        For example:
            >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
            >>> X.get_axis_list(axis='x', idx=0)
            [0, 1, 2]
            >>> X.get_axis_list(axis='y', idx=0)
            [0, 3]
        """
        # This could be faster by generating the desired keys directly instead of
        # filtering.
        axis_num = self.axes.index(axis)
        ranks = [self.mapping[k] for k in self.mapping.keys() if k[axis_num] == idx]
        return ranks

    def world_size(self):
        return len(self.mapping)

    def __str__(self):
        return str(self.mapping)


def _prime_factors(N):
    """ Returns the prime factorization of positive integer N. """
    if N <= 0:
        raise ValueError("Values must be strictly positive.")

    primes = []
    while N != 1:
        for candidate in range(2, N + 1):
            if N % candidate == 0:
                primes.append(candidate)
                N //= candidate
                break
    return primes
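
# For illustration: _prime_factors(12) == [2, 2, 3]. PipelineParallelGrid's fallback
# path below alternates these factors between the 'pipe' and 'data' axes, so a world
# size of 12 becomes num_pp=6 and num_dp=2.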


class PipeDataParallelTopology(ProcessTopology):
    """ A topology specialization for hybrid data and pipeline parallelism.

        Uses data parallelism on the last dimension to encourage gradient
        reductions to use high-bandwidth intra-node links and lower-volume
        pipeline communications to use low-bandwidth inter-node links.
    """

    def __init__(self, num_pp, num_dp):
        super().__init__(axes=['pipe', 'data'], dims=[num_pp, num_dp])


class PipeModelDataParallelTopology(ProcessTopology):
    """ A topology for hybrid pipeline, model, and data parallelism. """

    def __init__(self, num_pp, num_mp, num_dp):
        super().__init__(axes=['pipe', 'data', 'model'], dims=[num_pp, num_dp, num_mp])
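
# Illustrative mapping (row-major over axes ['pipe', 'data', 'model']), assuming
# PipeModelDataParallelTopology(num_pp=2, num_mp=2, num_dp=2):
#   rank 0 -> ProcessCoord(pipe=0, data=0, model=0)
#   rank 1 -> ProcessCoord(pipe=0, data=0, model=1)
#   rank 2 -> ProcessCoord(pipe=0, data=1, model=0)
#   ...
#   rank 7 -> ProcessCoord(pipe=1, data=1, model=1)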


class PipelineParallelGrid:
    """Implements a grid object that stores the data parallel ranks
    corresponding to each of the model parallel stages.

    The grid object organizes the processes in a distributed pytorch job
    into a 2D grid of stage_id and data_parallel_id.

    self.stage_id and self.data_parallel_id store the stage id
    and the data parallel id of the current process.

    self.dp_group groups the processes by stage_id.
    self.dp_group[i] is a list containing all process ranks whose
    stage_id is i.

    self.p2p_groups stores a list of tuples, where each tuple
    stores the process ranks of adjacent stages for a given data_parallel_id.
    For example, if num_stage is 5 then a tuple [7,8] represents stages [3, 4]
    with data_parallel_id = 1. A stage wrap-around appears as non-adjacent ranks;
    for example the tuple [4,0] represents the wrap-around from stage 4 to stage 0
    for data_parallel_id = 0, and similarly [9,5] represents the wrap-around
    stages [4, 0] for data_parallel_id = 1.
    """

    def __init__(self, topology=None, process_group=None):
        # TODO use process_group if provided
        self.global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        if topology is not None:
            if self.global_rank == 0:
                print('Using topology:', topology)
            self._topo = topology
        else:
            # Infer a pipe/data topology from the world size by alternating its
            # prime factors between the two axes.
            num_pp = 1
            num_dp = 1
            for idx, prime in enumerate(_prime_factors(self.world_size)):
                if idx % 2 == 0:
                    num_pp *= prime
                else:
                    num_dp *= prime
            self._topo = PipeDataParallelTopology(num_dp=num_dp, num_pp=num_pp)
        self.data_parallel_size = max(self._topo.get_dim('data'), 1)
        self.pipe_parallel_size = max(self._topo.get_dim('pipe'), 1)
        self.model_parallel_size = max(self._topo.get_dim('model'), 1)
        self.slice_parallel_size = self.model_parallel_size
        assert self._is_grid_valid(), "Invalid Grid"

        self.stage_id = self.get_stage_id()
        self.data_parallel_id = self.get_data_parallel_id()

        # Create new ProcessGroups for all model parallelism. DeepSpeedLight uses these
        # to detect overflow, etc.
        self.ds_model_proc_group = None
        self.ds_model_rank = -1
        for dp in range(self.data_parallel_size):
            ranks = sorted(self._topo.get_axis_list(axis='data', idx=dp))
            if self.global_rank == 0:
                #print(f'RANK={self.global_rank} building DeepSpeed model group: {ranks}')
                pass
            proc_group = dist.new_group(ranks=ranks)
            if self.global_rank in ranks:
                self.ds_model_proc_group = proc_group
                self.ds_model_world_size = len(ranks)
                self.ds_model_rank = ranks.index(self.global_rank)
        assert self.ds_model_rank > -1
        assert self.ds_model_proc_group is not None

        # Create new ProcessGroup for gradient all-reduces - these are the data parallel groups
        self.dp_group = []
        self.dp_groups = self._topo.get_axis_comm_lists('data')
        for g in self.dp_groups:
            proc_group = dist.new_group(ranks=g)
            if self.global_rank in g:
                self.dp_group = g
                self.dp_proc_group = proc_group

        self.is_first_stage = (self.stage_id == 0)
        self.is_last_stage = (self.stage_id == (self.pipe_parallel_size - 1))

        self.p2p_groups = self._build_p2p_groups()

        # Create new ProcessGroup for pipeline collectives - these are pipe parallel groups
        self.pp_group = []
        self.pp_proc_group = None
        self.pipe_groups = self._topo.get_axis_comm_lists('pipe')
        for ranks in self.pipe_groups:
            if self.global_rank == 0:
                #print(f'RANK={self.global_rank} building pipeline group: {ranks}')
                pass
            proc_group = dist.new_group(ranks=ranks)
            if self.global_rank in ranks:
                self.pp_group = ranks
                self.pp_proc_group = proc_group
        assert self.pp_proc_group is not None

        # Create new ProcessGroup for model (tensor-slicing) collectives

        # Short circuit case without model parallelism.
        # TODO: it would be nice if topology had bcast semantics to avoid this branching
        # case?
        if self.model_parallel_size == 1:
            for group_rank in range(self.world_size):
                group_rank = [group_rank]
                group = dist.new_group(ranks=group_rank)
                if group_rank[0] == self.global_rank:
                    self.slice_group = group_rank
                    self.slice_proc_group = group
            return
        else:
            self.mp_group = []
            self.model_groups = self._topo.get_axis_comm_lists('model')
            for g in self.model_groups:
                proc_group = dist.new_group(ranks=g)
                if self.global_rank in g:
                    self.slice_group = g
                    self.slice_proc_group = proc_group

    def get_stage_id(self):
        return self._topo.get_coord(rank=self.global_rank).pipe

    def get_data_parallel_id(self):
        return self._topo.get_coord(rank=self.global_rank).data

    def _build_p2p_groups(self):
        """Groups for sending and receiving activations and gradients across model
        parallel stages.
        """
        comm_lists = self._topo.get_axis_comm_lists('pipe')
        p2p_lists = []
        for rank in range(self.world_size):
            for l in comm_lists:
                assert len(l) == self.pipe_parallel_size
                if rank in l:
                    idx = l.index(rank)
                    buddy_rank = l[(idx + 1) % self.pipe_parallel_size]
                    p2p_lists.append([rank, buddy_rank])
                    break  # next global rank
        assert len(p2p_lists) == self.world_size
        return p2p_lists
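
    # Illustrative example, assuming PipeDataParallelTopology(num_pp=2, num_dp=2)
    # (so rank = 2*pipe + data): the pipe comm lists are [0, 2] (data=0) and
    # [1, 3] (data=1), and _build_p2p_groups() returns
    # [[0, 2], [1, 3], [2, 0], [3, 1]], where the last two pairs wrap around to stage 0.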

    def _is_grid_valid(self):
        ranks = 1
        for ax in self._topo.get_axis_names():
            ranks *= self._topo.get_dim(ax)
        return ranks == dist.get_world_size()

    # Returns the global rank of the process with the provided stage id,
    # which has the same data_parallel_id as the caller process.
    def stage_to_global(self, stage_id, **kwargs):
        me = self._topo.get_coord(self.global_rank)
        transform = me._replace(pipe=stage_id, **kwargs)._asdict()
        return self._topo.get_rank(**transform)

    def topology(self):
        return self._topo

    # MPU functions for DeepSpeed integration
    def get_global_rank(self):
        return self.global_rank

    def get_pipe_parallel_rank(self):
        """ The stage of the pipeline this rank resides in. """
        return self.get_stage_id()

    def get_pipe_parallel_world_size(self):
        """ The number of stages in the pipeline. """
        return self.pipe_parallel_size

    def get_pipe_parallel_group(self):
        """ The group of ranks within the same pipeline. """
        return self.pp_proc_group

    def get_data_parallel_rank(self):
        """ Which pipeline this rank resides in. """
        return self.data_parallel_id

    def get_data_parallel_world_size(self):
        """ The number of pipelines. """
        return self.data_parallel_size

    def get_data_parallel_group(self):
        """ The group of ranks within the same stage of all pipelines. """
        return self.dp_proc_group

    # These are model parallel groups across all types of model parallelism.
    # DeepSpeed uses them to detect overflow, etc.
    def get_model_parallel_rank(self):
        return self.ds_model_rank

    def get_model_parallel_world_size(self):
        return self.ds_model_world_size

    def get_model_parallel_group(self):
        return self.ds_model_proc_group

    # For Megatron-style tensor slicing
    def get_slice_parallel_rank(self):
        if 'model' in self._topo.get_axis_names():
            return self._topo.get_coord(rank=self.global_rank).model
        else:
            return 0

    def get_slice_parallel_world_size(self):
        return self.slice_parallel_size

    def get_slice_parallel_group(self):
        return self.slice_proc_group
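

# A minimal local usage sketch (illustrative; the dims below are arbitrary).
# ProcessTopology is pure Python, so this runs without an initialized process group,
# though the `deepspeed` import at the top of the module must still resolve.
if __name__ == '__main__':
    topo = PipeModelDataParallelTopology(num_pp=2, num_mp=2, num_dp=2)
    print(topo.get_rank(pipe=1, data=0, model=1))    # 5
    print(topo.get_axis_comm_lists('pipe'))          # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(topo.get_rank_repr(rank=3, omit_axes=[]))  # 'pipe_00-data_01-model_01'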