topology.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. # Copyright 2019 The Microsoft DeepSpeed Team
  2. from deepspeed.utils import logger
  3. import torch.distributed as dist
  4. import sys
  5. from collections import namedtuple
  6. from itertools import product as cartesian_product
  7. class ProcessTopology:
  8. """ Manages the mapping of n-dimensional Cartesian coordinates to linear
  9. indices. This mapping is used to map the rank of processes to the grid
  10. for various forms of parallelism.
  11. Each axis of the tensor is accessed by its name. The provided ordering
  12. of the axes defines the layout of the topology. ProcessTopology uses a "row-major"
  13. layout of the tensor axes, and so axes=['x', 'y'] would map coordinates (x,y) and
  14. (x,y+1) to adjacent linear indices. If instead axes=['y', 'x'] was used, coordinates
  15. (x,y) and (x+1,y) would be adjacent.
  16. Some methods return ProcessCoord namedtuples.
  17. """
  18. def __init__(self, axes, dims):
  19. """Create a mapping of n-dimensional tensor coordinates to linear indices.
  20. Arguments:
  21. axes (list): the names of the tensor axes
  22. dims (list): the dimension (length) of each axis of the topology tensor
  23. """
  24. self.axes = axes # names of each topology axis
  25. self.dims = dims # length of each topology axis
  26. # This is actually a class that lets us hash {'row':3, 'col':2} mappings
  27. self.ProcessCoord = namedtuple('ProcessCoord', axes)
  28. self.mapping = {}
  29. ranges = [range(d) for d in dims]
  30. # example: 1, (0,0,1)
  31. for global_rank, coord in enumerate(cartesian_product(*ranges)):
  32. key = {axis: coord[self.axes.index(axis)] for axis in self.axes}
  33. key = self.ProcessCoord(**key)
  34. # for example, {ProcessCoord(row=0, col=1) : 1}
  35. self.mapping[key] = global_rank
  36. def get_rank(self, **coord_kwargs):
  37. """Return the global rank of a process via its coordinates.
  38. Coordinates are specified as kwargs. For example:
  39. >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
  40. >>> X.get_rank(x=0, y=1)
  41. 1
  42. """
  43. if len(coord_kwargs) != len(self.axes):
  44. raise ValueError('get_rank() does not support slices. Use filter_match())')
  45. key = self.ProcessCoord(**coord_kwargs)
  46. assert key in self.mapping, f'key {kwargs} invalid'
  47. return self.mapping[key]
  48. def get_axis_names(self):
  49. """Return a list of the axis names in the ordering of the topology. """
  50. return self.axes
  51. def get_rank_repr(self,
  52. rank,
  53. omit_axes=['data',
  54. 'pipe'],
  55. inner_sep='_',
  56. outer_sep='-'):
  57. """Return a string representation of a rank.
  58. This method is primarily used for checkpointing model data.
  59. For example:
  60. >>> topo = Topo(axes=['a', 'b'], dims=[2, 2])
  61. >>> topo.get_rank_repr(rank=3)
  62. 'a_01-b_01'
  63. >>> topo.get_rank_repr(rank=3, omit_axes=['a'])
  64. 'b_01'
  65. Args:
  66. rank (int): A rank in the topology.
  67. omit_axes (list, optional): Axes that should not be in the representation. Defaults to ['data', 'pipe'].
  68. inner_sep (str, optional): [description]. Defaults to '_'.
  69. outer_sep (str, optional): [description]. Defaults to '-'.
  70. Returns:
  71. str: A string representation of the coordinate owned by ``rank``.
  72. """
  73. omit_axes = frozenset(omit_axes)
  74. axes = [a for a in self.get_axis_names() if a not in omit_axes]
  75. names = []
  76. for ax in axes:
  77. ax_rank = getattr(self.get_coord(rank=rank), ax)
  78. names.append(f'{ax}{inner_sep}{ax_rank:02d}')
  79. return outer_sep.join(names)
  80. def get_dim(self, axis):
  81. """Return the number of processes along the given axis.
  82. For example:
  83. >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
  84. >>> X.get_dim('y')
  85. 3
  86. """
  87. if axis not in self.axes:
  88. return 0
  89. return self.dims[self.axes.index(axis)]
  90. def get_coord(self, rank):
  91. """Return the coordinate owned by a process rank.
  92. The axes of the returned namedtuple can be directly accessed as members. For
  93. example:
  94. >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
  95. >>> coord = X.get_coord(rank=1)
  96. >>> coord.x
  97. 0
  98. >>> coord.y
  99. 1
  100. """
  101. for coord, idx in self.mapping.items():
  102. if idx == rank:
  103. return coord
  104. raise ValueError(f'rank {rank} not found in topology.')
  105. def get_axis_comm_lists(self, axis):
  106. """ Construct lists suitable for a communicator group along axis ``axis``.
  107. Example:
  108. >>> topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
  109. >>> topo.get_axis_comm_lists('pipe')
  110. [
  111. [0, 4], # data=0, model=0
  112. [1, 5], # data=0, model=1
  113. [2, 6], # data=1, model=0
  114. [3, 7], # data=1, model=1
  115. ]
  116. Returns:
  117. A list of lists whose coordinates match in all axes *except* ``axis``.
  118. """
  119. # We don't want to RuntimeError because it allows us to write more generalized
  120. # code for hybrid parallelisms.
  121. if axis not in self.axes:
  122. return []
  123. # Grab all axes but `axis`
  124. other_axes = [a for a in self.axes if a != axis]
  125. lists = []
  126. # Construct all combinations of coords with other_axes
  127. ranges = [range(self.get_dim(a)) for a in other_axes]
  128. for coord in cartesian_product(*ranges):
  129. other_keys = {a: coord[other_axes.index(a)] for a in other_axes}
  130. # now go over all ranks in `axis`.
  131. sub_list = []
  132. for axis_key in range(self.get_dim(axis)):
  133. key = self.ProcessCoord(**other_keys, **{axis: axis_key})
  134. sub_list.append(self.mapping[key])
  135. lists.append(sub_list)
  136. return lists
  137. def filter_match(self, **filter_kwargs):
  138. """Return the list of ranks whose coordinates match the provided criteria.
  139. Example:
  140. >>> X = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
  141. >>> X.filter_match(pipe=0, data=1)
  142. [2, 3]
  143. >>> [X.get_coord(rank) for rank in X.filter_match(pipe=0, data=1)]
  144. [ProcessCoord(pipe=0, data=1, model=0), ProcessCoord(pipe=0, data=1, model=1)]
  145. Arguments:
  146. **filter_kwargs (dict): criteria used to select coordinates.
  147. Returns:
  148. The list of ranks whose coordinates match filter_kwargs.
  149. """
  150. def _filter_helper(x):
  151. for key, val in filter_kwargs.items():
  152. if getattr(x, key) != val:
  153. return False
  154. return True
  155. coords = filter(_filter_helper, self.mapping.keys())
  156. return [self.mapping[coo] for coo in coords]
  157. def get_axis_list(self, axis, idx):
  158. """Returns the list of global ranks whose coordinate in an axis is idx.
  159. For example:
  160. >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
  161. >>> X.get_axis_list(axis='x', idx=0)
  162. [0, 1, 2]
  163. >>> X.get_axis_list(axis='y', idx=0)
  164. [0, 3]
  165. """
  166. # This could be faster by generating the desired keys directly instead of
  167. # filtering.
  168. axis_num = self.axes.index(axis)
  169. ranks = [self.mapping[k] for k in self.mapping.keys() if k[axis_num] == idx]
  170. return ranks
  171. def world_size(self):
  172. return len(self.mapping)
  173. def __str__(self):
  174. return str(self.mapping)
  175. def _prime_factors(N):
  176. """ Returns the prime factorization of positive integer N. """
  177. if N <= 0:
  178. raise ValueError("Values must be strictly positive.")
  179. primes = []
  180. while N != 1:
  181. for candidate in range(2, N + 1):
  182. if N % candidate == 0:
  183. primes.append(candidate)
  184. N //= candidate
  185. break
  186. return primes
  187. class PipeDataParallelTopology(ProcessTopology):
  188. """ A topology specialization for hybrid data and pipeline parallelism.
  189. Uses data parallelism on the last dimension to encourage gradient
  190. reductions to use high-bandwidth intra-node links and lower-volume
  191. pipeline communications to use low-bandwidth inter-node links.
  192. """
  193. def __init__(self, num_pp, num_dp):
  194. super().__init__(axes=['pipe', 'data'], dims=[num_pp, num_dp])
  195. class PipeModelDataParallelTopology(ProcessTopology):
  196. """ A topology for hybrid pipeline, model, and data parallelism. """
  197. def __init__(self, num_pp, num_mp, num_dp):
  198. super().__init__(axes=['pipe', 'data', 'model'], dims=[num_pp, num_dp, num_mp])
class PipelineParallelGrid:
    """Implements a grid object that stores the data parallel ranks
    corresponding to each of the model parallel stages

    The grid object organizes the processes in a distributed pytorch job
    into a 2D grid, of stage_id and data_parallel_id.

    self.stage_id and self.data_parallel_id stores the stage id
    and the data parallel id of current process.

    self.dp_group groups the processes by stage_id.
    self.dp_group[i], is a list containing all process ranks whose
    stage_id is i.

    self.p2p_groups stores a list of tuple, where each tuple
    stores process ranks of adjacent stages for a given data_parallel_id.
    For example if num_stage is 5 then a tuple [7,8] represents stages [3, 4],
    with data_parallel id = 1. A stage wrap around will appear as non-adjacent ranks,
    for example tuple [4,0] with representing wrap-around stage 4 and 0, for
    data_parallel_id = 0, or similarly [9,5] represents wrapped around stages [4,0]
    for data_parallel_id = 1.
    """
    def __init__(self, topology=None, process_group=None):
        # NOTE: every rank must execute the same sequence of dist.new_group()
        # calls in the same order, which is why groups are created for ALL
        # coordinates below and each rank keeps only the group it belongs to.
        # TODO use process_group if provided
        self.global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        if topology is not None:
            if self.global_rank == 0:
                print('Using topology:', topology)
            self._topo = topology
        else:
            # No topology provided: derive a 2D (pipe, data) grid from the
            # prime factorization of the world size, alternating factors
            # between the pipe and data dimensions.
            num_pp = 1
            num_dp = 1
            for idx, prime in enumerate(_prime_factors(self.world_size)):
                if idx % 2 == 0:
                    num_pp *= prime
                else:
                    num_dp *= prime
            self._topo = PipeDataParallelTopology(num_dp=num_dp, num_pp=num_pp)

        # get_dim() returns 0 for an absent axis; clamp to 1 so sizes are
        # always valid divisors/loop bounds.
        self.data_parallel_size = max(self._topo.get_dim('data'), 1)
        self.pipe_parallel_size = max(self._topo.get_dim('pipe'), 1)
        self.model_parallel_size = max(self._topo.get_dim('model'), 1)
        self.slice_parallel_size = self.model_parallel_size
        assert self._is_grid_valid(), "Invalid Grid"

        self.stage_id = self.get_stage_id()
        self.data_parallel_id = self.get_data_parallel_id()

        # Create new ProcessGroups for all model parallelism. DeepSpeedLight uses these
        # to detect overflow, etc.
        self.ds_model_proc_group = None
        self.ds_model_rank = -1
        for dp in range(self.data_parallel_size):
            ranks = sorted(self._topo.get_axis_list(axis='data', idx=dp))
            if self.global_rank == 0:
                #print(f'RANK={self.global_rank} building DeepSpeed model group: {ranks}')
                pass
            proc_group = dist.new_group(ranks=ranks)
            if self.global_rank in ranks:
                self.ds_model_proc_group = proc_group
                self.ds_model_world_size = len(ranks)
                self.ds_model_rank = ranks.index(self.global_rank)
        assert self.ds_model_rank > -1
        assert self.ds_model_proc_group is not None

        # Create new ProcessGroup for gradient all-reduces - these are the data parallel groups
        self.dp_group = []
        self.dp_groups = self._topo.get_axis_comm_lists('data')
        for g in self.dp_groups:
            proc_group = dist.new_group(ranks=g)
            if self.global_rank in g:
                self.dp_group = g
                self.dp_proc_group = proc_group

        self.is_first_stage = (self.stage_id == 0)
        self.is_last_stage = (self.stage_id == (self.pipe_parallel_size - 1))

        self.p2p_groups = self._build_p2p_groups()

        # Create new ProcessGroup for pipeline collectives - these are pipe parallel groups
        self.pp_group = []
        self.pp_proc_group = None
        self.pipe_groups = self._topo.get_axis_comm_lists('pipe')
        for ranks in self.pipe_groups:
            if self.global_rank == 0:
                #print(f'RANK={self.global_rank} building pipeline group: {ranks}')
                pass
            proc_group = dist.new_group(ranks=ranks)
            if self.global_rank in ranks:
                self.pp_group = ranks
                self.pp_proc_group = proc_group
        assert self.pp_proc_group is not None

        # Create new ProcessGroup for model (tensor-slicing) collectives

        # Short circuit case without model parallelism.
        # TODO: it would be nice if topology had bcast semantics to avoid this branching
        # case?
        if self.model_parallel_size == 1:
            # Each rank gets a singleton slice group containing only itself.
            for group_rank in range(self.world_size):
                group_rank = [group_rank]
                group = dist.new_group(ranks=group_rank)
                if group_rank[0] == self.global_rank:
                    self.slice_group = group_rank
                    self.slice_proc_group = group
            return
        else:
            self.mp_group = []
            self.model_groups = self._topo.get_axis_comm_lists('model')
            for g in self.model_groups:
                proc_group = dist.new_group(ranks=g)
                if self.global_rank in g:
                    self.slice_group = g
                    self.slice_proc_group = proc_group

    def get_stage_id(self):
        """Pipeline stage (coordinate on the 'pipe' axis) of this rank."""
        return self._topo.get_coord(rank=self.global_rank).pipe

    def get_data_parallel_id(self):
        """Coordinate of this rank on the 'data' axis."""
        return self._topo.get_coord(rank=self.global_rank).data

    def _build_p2p_groups(self):
        """Groups for sending and receiving activations and gradients across model
        parallel stages.
        """
        comm_lists = self._topo.get_axis_comm_lists('pipe')
        # One [rank, next-stage-rank] pair per global rank; the last stage
        # wraps around to the first (hence the modulo below).
        p2p_lists = []
        for rank in range(self.world_size):
            for l in comm_lists:
                assert len(l) == self.pipe_parallel_size
                if rank in l:
                    idx = l.index(rank)
                    buddy_rank = l[(idx + 1) % self.pipe_parallel_size]
                    p2p_lists.append([rank, buddy_rank])
                    break  # next global rank
        assert len(p2p_lists) == self.world_size
        return p2p_lists

    def _is_grid_valid(self):
        # The product of all axis dimensions must equal the world size.
        ranks = 1
        for ax in self._topo.get_axis_names():
            ranks *= self._topo.get_dim(ax)
        return ranks == dist.get_world_size()

    #returns the global rank of the process with the provided stage id
    #which has the same data_parallel_id as caller process
    def stage_to_global(self, stage_id, **kwargs):
        # Take this rank's coordinate, override the pipe axis (and any other
        # axes passed via kwargs), and map back to a global rank.
        me = self._topo.get_coord(self.global_rank)
        transform = me._replace(pipe=stage_id, **kwargs)._asdict()
        return self._topo.get_rank(**transform)

    def topology(self):
        """The underlying ProcessTopology."""
        return self._topo

    # MPU functions for DeepSpeed integration
    def get_global_rank(self):
        return self.global_rank

    def get_pipe_parallel_rank(self):
        """ The stage of the pipeline this rank resides in. """
        return self.get_stage_id()

    def get_pipe_parallel_world_size(self):
        """ The number of stages in the pipeline. """
        return self.pipe_parallel_size

    def get_pipe_parallel_group(self):
        """ The group of ranks within the same pipeline. """
        return self.pp_proc_group

    def get_data_parallel_rank(self):
        """ Which pipeline this rank resides in. """
        return self.data_parallel_id

    def get_data_parallel_world_size(self):
        """ The number of pipelines. """
        return self.data_parallel_size

    def get_data_parallel_group(self):
        """ The group of ranks within the same stage of all pipelines. """
        return self.dp_proc_group

    # These are model parallel groups across all types of model parallelism.
    # Deepspeed uses them to detect overflow, etc.
    def get_model_parallel_rank(self):
        return self.ds_model_rank

    def get_model_parallel_world_size(self):
        return self.ds_model_world_size

    def get_model_parallel_group(self):
        return self.ds_model_proc_group

    # For Megatron-style tensor slicing
    def get_slice_parallel_rank(self):
        # Topologies without a 'model' axis have no slicing; rank is 0.
        if 'model' in self._topo.get_axis_names():
            return self._topo.get_coord(rank=self.global_rank).model
        else:
            return 0

    def get_slice_parallel_world_size(self):
        return self.slice_parallel_size

    def get_slice_parallel_group(self):
        return self.slice_proc_group