# reshape_meg_2d.py

'''Copyright The Microsoft DeepSpeed Team'''

from .reshape_utils import partition_data


class meg_2d_parallel_map(object):
    """Maps each (pp_index, tp_index) cell of a 2D parallel grid to a list of global ranks."""
    def __init__(self, pp_degree, tp_degree):
        self.pp_degree = pp_degree
        self.tp_degree = tp_degree
        self.map = {}

    def simple_init(self):
        # One rank per grid cell, laid out row-major over (pp, tp).
        self.map = {
            self._make_key(i // self.tp_degree,
                           i % self.tp_degree): [i]
            for i in range(self.pp_degree * self.tp_degree)
        }

    def add_data(self, pp_index, tp_index, data):
        self._validate_indices(pp_index, tp_index)
        assert type(data) is list

        key = self._make_key(pp_index, tp_index)
        if key not in self.map:
            self.map[key] = []
        self.map[key] += data

    def get_data(self, pp_index=None, tp_index=None):
        # A None index selects every index along that dimension.
        self._validate_indices(pp_index, tp_index)
        pp_indices = list(range(self.pp_degree)) if pp_index is None else [pp_index]
        tp_indices = list(range(self.tp_degree)) if tp_index is None else [tp_index]

        result = []
        for i in pp_indices:
            for j in tp_indices:
                result += self.map[self._make_key(i, j)]

        return result

    def print_data(self, tag):
        print(tag)
        for key, value in self.map.items():
            print(f'{key} = {value}')

    def _validate_indices(self, pp_index, tp_index):
        assert pp_index is None or pp_index < self.pp_degree
        assert tp_index is None or tp_index < self.tp_degree

    def _make_key(self, i, j):
        return f'{i},{j}'
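

# Illustrative sketch (not part of the original module): for pp_degree=2 and
# tp_degree=2, simple_init() builds a one-rank-per-cell map keyed by 'pp,tp':
#
#   m = meg_2d_parallel_map(pp_degree=2, tp_degree=2)
#   m.simple_init()
#   m.map  # -> {'0,0': [0], '0,1': [1], '1,0': [2], '1,1': [3]}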


def _reshape_tp_dimension(old_2d_map, new_tp_degree):
    # Repartition each pipeline stage's ranks across the new tp dimension.
    old_pp_degree = old_2d_map.pp_degree
    new_2d_map = meg_2d_parallel_map(old_pp_degree, new_tp_degree)
    for i in range(old_pp_degree):
        ranks_for_pp_index = old_2d_map.get_data(pp_index=i, tp_index=None)
        split_ranks = partition_data(ranks_for_pp_index, new_tp_degree)
        for j in range(new_tp_degree):
            new_2d_map.add_data(i, j, split_ranks[j])

    return new_2d_map


def _reshape_pp_dimension(old_2d_map, new_pp_degree):
    # Repartition each tensor-parallel slice's ranks across the new pp dimension.
    old_tp_degree = old_2d_map.tp_degree
    new_2d_map = meg_2d_parallel_map(new_pp_degree, old_tp_degree)
    for i in range(old_tp_degree):
        ranks_for_tp_index = old_2d_map.get_data(pp_index=None, tp_index=i)
        split_ranks = partition_data(ranks_for_tp_index, new_pp_degree)
        for j in range(new_pp_degree):
            new_2d_map.add_data(j, i, split_ranks[j])

    return new_2d_map
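

# Illustrative sketch (not part of the original module), assuming that
# partition_data() from reshape_utils splits a list into equal-sized
# contiguous chunks: contracting tp from 4 to 2 at pp_index=0 merges
# neighbouring tp ranks into the same new cell, e.g.
#
#   old: {'0,0': [0], '0,1': [1], '0,2': [2], '0,3': [3]}
#   new: {'0,0': [0, 1], '0,1': [2, 3]}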


def reshape_meg_2d_parallel(old_pp_degree,
                            old_tp_degree,
                            new_pp_degree,
                            new_tp_degree,
                            verbose=False):
    # Only contraction is supported: the new degrees may not exceed the old ones.
    assert new_pp_degree <= old_pp_degree
    assert new_tp_degree <= old_tp_degree

    old_2d_map = meg_2d_parallel_map(old_pp_degree, old_tp_degree)
    old_2d_map.simple_init()
    if verbose:
        old_2d_map.print_data('original_2d_map:')

    if old_tp_degree != new_tp_degree:
        new_tp_map = _reshape_tp_dimension(old_2d_map, new_tp_degree)
    else:
        new_tp_map = old_2d_map
    if verbose:
        new_tp_map.print_data('after_tp_reshape:')

    if old_pp_degree != new_pp_degree:
        final_map = _reshape_pp_dimension(new_tp_map, new_pp_degree)
    else:
        final_map = new_tp_map

    if verbose:
        final_map.print_data('final_2d_map:')

    return final_map
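

# Illustrative sketch (not part of the original module), under the same
# contiguous-split assumption about partition_data(): contracting a pp=2, tp=4
# grid to pp=2, tp=2 assigns each new cell the two old ranks it should merge.
#
#   m = reshape_meg_2d_parallel(old_pp_degree=2, old_tp_degree=4,
#                               new_pp_degree=2, new_tp_degree=2)
#   m.map  # -> {'0,0': [0, 1], '0,1': [2, 3], '1,0': [4, 5], '1,1': [6, 7]}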


def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
    """
    Compute model- and data-parallel rank groupings (the groups are computed
    and printed, not turned into process groups).

    Arguments:
        tp_size: number of GPUs used to parallelize model tensor.
        pp_size: number of GPUs used to parallelize model pipeline.
        dp_size: number of GPUs used to parallelize model data.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will compute 8 tensor
    model-parallel groups, 4 pipeline model-parallel groups and
    8 data-parallel groups as:
        8 data-parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    world_size = tp_size * pp_size * dp_size
    print(f"\n\n*** tp={tp_size}, pp={pp_size}, dp={dp_size}, world={world_size}")

    tensor_model_parallel_size = min(tp_size, world_size)
    pipeline_model_parallel_size = min(pp_size, world_size)
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size

    # Build the data-parallel groups.
    all_dp_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_dp_group_ranks.append(list(ranks))
    print("DP", all_dp_group_ranks)

    # Build the model-parallel groups.
    all_pp_group_ranks = []
    for i in range(data_parallel_size):
        ranks = [
            data_parallel_group_ranks[i]
            for data_parallel_group_ranks in all_dp_group_ranks
        ]
        all_pp_group_ranks.append(list(ranks))
    print("PP", all_pp_group_ranks)

    # Build the tensor model-parallel groups.
    all_tp_group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        all_tp_group_ranks.append(list(ranks))
    print("TP", all_tp_group_ranks)

    return all_tp_group_ranks, all_pp_group_ranks, all_dp_group_ranks

    # # Build the pipeline model-parallel groups and embedding groups
    # # (first and last rank in each pipeline model-parallel group).
    # for i in range(num_pipeline_model_parallel_groups):
    #     ranks = range(i, world_size,
    #                   num_pipeline_model_parallel_groups)
    #     print(f"EMB{i}", list(ranks))
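

# Illustrative sketch (not part of the original module): for the 16-GPU example
# in the docstring above,
#
#   tp, pp, dp = get_mpu_ranks(tp_size=2, pp_size=4, dp_size=2)
#
# returns dp groups [0, 2], [1, 3], ..., [13, 15], tp groups [0, 1], [2, 3],
# ..., [14, 15], and, as the "pp" entry, the two combined model-parallel groups
# [0, 1, 4, 5, 8, 9, 12, 13] and [2, 3, 6, 7, 10, 11, 14, 15].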


def reshape(src, tgt):
    """
    reshape([tp_size_src, pp_size_src, dp_size_src],
            [tp_size_tgt, pp_size_tgt, dp_size_tgt])
    """

    print(f"\n\n*** Reshaping: {src} => {tgt}")

    tp_size_src, pp_size_src, dp_size_src = src
    tp_size_tgt, pp_size_tgt, dp_size_tgt = tgt

    tp_ranks1, pp_ranks1, dp_ranks1 = get_mpu_ranks(tp_size=tp_size_src,
                                                    pp_size=pp_size_src,
                                                    dp_size=dp_size_src)
    tp_ranks2, pp_ranks2, dp_ranks2 = get_mpu_ranks(tp_size=tp_size_tgt,
                                                    pp_size=pp_size_src,
                                                    dp_size=dp_size_src)
    tp_ranks3, pp_ranks3, dp_ranks3 = get_mpu_ranks(tp_size=tp_size_tgt,
                                                    pp_size=pp_size_tgt,
                                                    dp_size=dp_size_src)

    # handle tp contraction first
    print("\n*** TP contraction:")
    for i, r in enumerate(tp_ranks1):
        print(f'{tp_ranks1[i]} => {tp_ranks2[i]}')

    # handle pp contraction next
    print("\n*** PP contraction:")
    for i, r in enumerate(pp_ranks1):
        print(f'{pp_ranks2[i]} => {pp_ranks3[i]}')


# easy
# reshape([2,2,1], [1,1,1])

# probably need more logic to suggest how to pack
# reshape([4,4,1], [2,2,1])

# reshape([2,4,2], [8,32,1])

# get_mpu_ranks(2,2,2)
# get_mpu_ranks(4,2,1)
# get_mpu_ranks(2,4,1)
# get_mpu_ranks(1,1,8)
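

# Illustrative sketch (not part of the original module): a runnable demo of
# reshape_meg_2d_parallel(). Because of the relative import at the top, run it
# as a module from inside its package (e.g. `python -m <package>.reshape_meg_2d`).
if __name__ == "__main__":
    reshape_meg_2d_parallel(old_pp_degree=4,
                            old_tp_degree=4,
                            new_pp_degree=2,
                            new_tp_degree=2,
                            verbose=True)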