# zero_checkpoint.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from .constants import (BASE_OPTIMIZER_STATE, GROUP_PADDINGS, OPTIMIZER_STATE_DICT, PARTITION_COUNT)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)

# Key under BASE_OPTIMIZER_STATE whose value maps param-group index -> optimizer
# state dict (e.g. exp_avg / exp_avg_sq tensors) in a loaded ZeRO state file.
GROUP_STATE_KEY = 'state'
  9. class ZeROCheckpoint(object):
  10. def __init__(self, dir):
  11. basic_folder_validation(dir)
  12. self.dir = dir
  13. self.file_list = get_zero_files(dir)
  14. self.num_files = len(self.file_list)
  15. assert self.num_files > 0, f'No ZeRO files found in {dir}'
  16. self.src_3d = get_model_3d_descriptor(dir)
  17. self.target_3d = model_3d_desc(pp_degree=self.src_3d.pp_degree,
  18. tp_degree=self.src_3d.tp_degree,
  19. dp_degree=self.src_3d.dp_degree)
  20. self._3d_file_map = self.src_3d.reshape(self.target_3d)
  21. def get_src_world_size(self):
  22. return self.src_3d.world_size()
  23. def get_src_tp_degree(self):
  24. return self.src_3d.tp_degree
  25. def get_src_pp_degree(self):
  26. return self.src_3d.pp_degree
  27. def get_src_dp_degree(self):
  28. return self.src_3d.dp_degree
  29. def get_file_indices_for_rank(self, pp_index, tp_index, dp_index):
  30. assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}'
  31. dp_2d_map = self._3d_file_map[dp_index]
  32. return dp_2d_map.get_data(pp_index, tp_index)
  33. def get_files_for_rank(self, pp_index, tp_index, dp_index):
  34. file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index)
  35. return [self.file_list[idx] for idx in file_idx_list]
  36. def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], strip_tensor_paddings=True):
  37. state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
  38. merged_sd = None
  39. for state_file in state_file_list:
  40. sd = torch.load(state_file, map_location=torch.device('cpu'))
  41. for key in keys_to_ignore:
  42. sd.pop(key, None)
  43. if strip_tensor_paddings:
  44. self._strip_tensor_paddings(sd)
  45. if merged_sd is None:
  46. merged_sd = sd
  47. else:
  48. merged_sd = merge_state(merged_sd, sd)
  49. self._update_partition_count(merged_sd)
  50. if strip_tensor_paddings:
  51. self._clear_group_paddings(merged_sd)
  52. return merged_sd
  53. def print_3d_index_map(self, tag=None):
  54. if tag:
  55. print(f'3D index map: {tag}')
  56. for dp_index, _2d_map in enumerate(self._3d_file_map):
  57. _2d_map.print_data(f'dp = {dp_index}')
  58. def print_3d_file_map(self, tag=None):
  59. if tag:
  60. print(f'3D file map: {tag}')
  61. for dp_index, _2d_map in enumerate(self._3d_file_map):
  62. for pp_index in _2d_map.pp_degree:
  63. for tp_index in _2d_map.tp_degree:
  64. file_index_list = _2d_map.get_data(pp_index, tp_index)
  65. file_list = [self.file_list[idx] for idx in file_index_list]
  66. print(f'{pp_index}, {tp_index}, {dp_index} => {file_list}')
  67. def reshape(self, target_3d_desc: model_3d_desc):
  68. self.target_3d = target_3d_desc
  69. self._3d_file_map = self.src_3d.reshape(self.target_3d)
  70. def _strip_tensor_paddings(self, sd):
  71. param_group_states = self._get_param_group_states(sd)
  72. if param_group_states is None:
  73. return
  74. group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
  75. if group_paddings is None:
  76. return
  77. for key, group_state in param_group_states.items():
  78. if group_paddings[key] == 0:
  79. continue
  80. for state_name, state_value in group_state.items():
  81. if state_name != "step" and torch.is_tensor(state_value):
  82. raw_length = state_value.numel() - group_paddings[key]
  83. group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
  84. else:
  85. group_state[state_name] = state_value
  86. def _clear_group_paddings(self, sd):
  87. group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
  88. if group_paddings:
  89. num_groups = len(group_paddings)
  90. sd[OPTIMIZER_STATE_DICT][GROUP_PADDINGS] = [0] * num_groups
  91. def _get_optimizer_state(self, sd, state_key):
  92. optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
  93. if optimizer_state is None:
  94. return None
  95. return optimizer_state.get(state_key, None)
  96. def _get_param_group_states(self, sd):
  97. optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
  98. if optimizer_state is None:
  99. return None
  100. base_optimizer_state = optimizer_state.get(BASE_OPTIMIZER_STATE, None)
  101. if base_optimizer_state is None:
  102. return None
  103. return base_optimizer_state.get(GROUP_STATE_KEY, None)
  104. def _update_partition_count(self, sd):
  105. partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT)
  106. if partition_counts:
  107. num_groups = len(partition_counts)
  108. sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree] * num_groups