reshape_3d_utils.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. from .reshape_utils import (get_files,
  3. get_files_with_prefix,
  4. partition_data,
  5. get_zero_files)
  6. from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
  7. from .reshape_meg_2d import (reshape_meg_2d_parallel, meg_2d_parallel_map)
  8. PP_DIM = 'PP'
  9. TP_DIM = 'TP'
  10. DP_DIM = 'DP'
  11. class model_3d_desc(object):
  12. def __init__(self, pp_degree=1, tp_degree=1, dp_degree=1):
  13. self.pp_degree = pp_degree
  14. self.tp_degree = tp_degree
  15. self.dp_degree = dp_degree
  16. def reshape(self, target_3d_desc, verbose=False):
  17. valid_reshape, reshape_errors = self.can_reshape(target_3d_desc)
  18. assert valid_reshape, ','.join(reshape_errors)
  19. tgt_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.pp_degree,
  20. old_tp_degree=self.tp_degree,
  21. new_pp_degree=target_3d_desc.pp_degree,
  22. new_tp_degree=target_3d_desc.tp_degree,
  23. verbose=verbose)
  24. flat_3d_map = flatten_dp_dimension(meg_2d_map=tgt_2d_map,
  25. src_2d_size=self.pp_degree * self.tp_degree,
  26. dp_degree=self.dp_degree)
  27. return unflatten_dp_dimension(meg_2d_map=flat_3d_map,
  28. dp_degree=target_3d_desc.dp_degree)
  29. def get_desc(self):
  30. return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})'
  31. def world_size(self):
  32. return self.pp_degree * self.tp_degree * self.dp_degree
  33. def is_valid(self, pp_index, tp_index, dp_index):
  34. err_msg = []
  35. valid = True
  36. for index, degree, dim_name in [
  37. (pp_index, self.pp_degree, PP_DIM),
  38. (tp_index, self.tp_degree, TP_DIM),
  39. (dp_index, self.dp_degree, DP_DIM)]:
  40. if index >= degree:
  41. valid = False
  42. err_msg.append(
  43. f'{dim_name} indexing error: index {index} >= degree {degree}')
  44. return valid, err_msg
  45. def can_reshape(self, target_3d_desc):
  46. err_msg = []
  47. if target_3d_desc.pp_degree > self.pp_degree:
  48. err_msg.append(
  49. f'Expansion reshape not supported - {PP_DIM}: {self.pp_degree} ---> {target_3d_desc.pp_degree}'
  50. )
  51. if target_3d_desc.tp_degree > self.tp_degree:
  52. err_msg.append(
  53. f'Expansion reshape not supported - {TP_DIM}: {self.tp_degree} ---> {target_3d_desc.tp_degree}'
  54. )
  55. if target_3d_desc.dp_degree > self.dp_degree:
  56. err_msg.append(
  57. f'Expansion reshape not supported - {DP_DIM}: {self.dp_degree} ---> {target_3d_desc.dp_degree}'
  58. )
  59. return len(err_msg) == 0, err_msg
  60. def get_model_3d_descriptor(dir):
  61. file_list = get_files(dir)
  62. zero_file_list = get_zero_files(dir)
  63. num_pp0_files = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01'))
  64. if num_pp0_files > 0:
  65. tp_degree = num_pp0_files
  66. pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree
  67. dp_degree = max(1, len(zero_file_list) // (pp_degree * tp_degree))
  68. else:
  69. tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX))
  70. dp_degree = max(1, len(zero_file_list) // tp_degree)
  71. pp_degree = 0
  72. return model_3d_desc(pp_degree, tp_degree, dp_degree)
  73. def flatten_dp_dimension(meg_2d_map, src_2d_size, dp_degree):
  74. new_meg_2d_map = meg_2d_parallel_map(meg_2d_map.pp_degree, meg_2d_map.tp_degree)
  75. for pp_index in range(meg_2d_map.pp_degree):
  76. for tp_index in range(meg_2d_map.tp_degree):
  77. dp0_indices = meg_2d_map.get_data(pp_index, tp_index)
  78. for idx in dp0_indices:
  79. dpX_indices = [idx + (i * src_2d_size) for i in range(dp_degree)]
  80. new_meg_2d_map.add_data(pp_index, tp_index, dpX_indices)
  81. return new_meg_2d_map
  82. def unflatten_dp_dimension(meg_2d_map, dp_degree):
  83. pp_degree = meg_2d_map.pp_degree
  84. tp_degree = meg_2d_map.tp_degree
  85. meg_2d_map_list = [
  86. meg_2d_parallel_map(pp_degree=pp_degree,
  87. tp_degree=tp_degree) for _ in range(dp_degree)
  88. ]
  89. for pp_index in range(pp_degree):
  90. for tp_index in range(tp_degree):
  91. flat_dp_indices = meg_2d_map.get_data(pp_index, tp_index)
  92. partitioned_dp_indices = partition_data(flat_dp_indices, dp_degree)
  93. for dp_indices, _2d_map in zip(partitioned_dp_indices, meg_2d_map_list):
  94. _2d_map.add_data(pp_index, tp_index, dp_indices)
  95. return meg_2d_map_list