universal_checkpoint.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
import torch
import types
from typing import List, Tuple, Union
from dataclasses import dataclass

from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE)

@dataclass
class SubparamShape:
    """Describes a parameter that packs multiple sub-parameters: `patterns` lists the
    parameter-name patterns it applies to, `shape` gives the full shape where the entry at
    `partition_dim` may itself be a tuple of per-sub-param sizes, and `partition_dim` is the
    dimension along which the parameter is TP-partitioned."""
    patterns: List[str]
    shape: Tuple[Union[Tuple[int], int]]
    partition_dim: int
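
# A minimal illustrative sketch (added by the editor, not part of the original file): how a
# SubparamShape might describe a fused gate/up projection weight that packs two sub-params
# along its TP-partitioned dim 0. The name pattern and dimensions below are assumptions.
_EXAMPLE_FUSED_SUBPARAM_SHAPE = SubparamShape(
    patterns=[r'.*mlp\.gate_up_proj\.weight'],  # assumed parameter-name pattern
    shape=((11008, 11008), 4096),  # two sub-params of 11008 rows each, hidden size 4096
    partition_dim=0,  # TP-partitioned along dim 0 (column-parallel)
)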


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
    hp_mapping = self._hp_mapping
    hp_mapping.optim_fragment = {}

    hp_keys = []
    for file in os.listdir(folder):
        # We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
        pattern = r'(.+)\.pt'
        match = re.search(pattern, file)
        if match:
            hp_keys.append(match.group(1))
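
    # For example (editor's illustration), a folder containing fp32.pt, exp_avg.pt,
    # exp_avg_sq.pt and step.pt yields hp_keys containing 'fp32', 'exp_avg', 'exp_avg_sq'
    # and 'step' (listing order is filesystem-dependent); the Adam state names are an
    # assumption about the optimizer in use.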
    step = None
    for key in hp_keys:
        ckpt_file = os.path.join(folder, f"{key}.pt")
        ckpt_dict = torch.load(ckpt_file)

        if key == "step":
            step = ckpt_dict
            continue

        full_hp_param = ckpt_dict[PARAM]

        # We need to deal with slices that were averaged across TP ranks. Undoing the
        # averaging amounts to an exact copy of the first slice. Two possible implementations:
        #
        # a. have the client pass a dict of name patterns to treat as averaged, e.g.:
        #    if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS):
        #        tp_rank = 0
        #        tp_world_size = 1
        #
        # b. assume the saved data is correct: if full_hp_param.shape == self.shape, the
        #    slices must have been averaged, so we copy without re-slicing. This version
        #    requires no additional data from the client, so it is what we do here.
        if full_hp_param.shape == self.shape:
            tp_rank = 0
            tp_world_size = 1
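
        # For example, a replicated tensor such as a LayerNorm weight is averaged across TP
        # ranks during conversion, so the universal tensor already has the per-rank shape and
        # every rank copies it verbatim. (Editor's illustration; the LayerNorm case is an
        # assumed example, not taken from this file.)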

        # Special case for word_embeddings weights, which get padded differently depending on
        # the TP degree. The converter to universal currently strips the original padding
        # completely, so the saved weight is padding-free and we just need to add new padding
        # for the target TP degree.
        is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
        if is_vocab_tensor:
            # Since the user passes no explicit padded vocab size for the target TP degree,
            # we derive it by reverse-engineering the target shapes:
            padded_target_vocab_size = self.shape[0] * tp_world_size
            assert padded_target_vocab_size >= full_hp_param.shape[0], \
                f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
            if padded_target_vocab_size > full_hp_param.shape[0]:
                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
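
        # Illustrative numbers (editor's assumption, not from any checkpoint): with a true
        # vocab of 50257 tokens, a per-rank embedding of self.shape[0] == 12576 rows and
        # tp_world_size == 4 give padded_target_vocab_size == 50304, so 47 rows of zeros are
        # appended to the padding-free universal weight.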

        full_param_numel = full_hp_param.numel()
        tp_slice_numel = self.numel()
        # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder:
        #     print_rank_0(f'{full_hp_param[:10]=}', force=True)
        assert full_param_numel == tp_world_size * tp_slice_numel, \
            f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'

        # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
        # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")

        sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None)
        # The many-to-1 TP merge cats sometimes on dim=0 and other times on dim=1, so we have
        # to do exactly the same in reverse here. A special case is when a single parameter is
        # effectively a container for multiple sub-parameters (more details at the
        # PARAM_N_SUB_PARAMS definition).
        chunk_dim = ckpt_dict.get(CAT_DIM, 0)
        n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
        if sub_param_shape:
            partition_dim = sub_param_shape.partition_dim
            sub_dim_sizes = sub_param_shape.shape[partition_dim]
            if not isinstance(sub_dim_sizes, tuple):
                sub_dim_sizes = (sub_dim_sizes, )

            partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape]
            full_hp_param = full_hp_param.view(partition_shape)

            offset = 0
            merged_chunks = []
            for sub_dim_size in sub_dim_sizes:
                sub_params_tp_slice = full_hp_param.narrow(partition_dim, offset,
                                                           sub_dim_size).chunk(tp_world_size,
                                                                               dim=partition_dim)[tp_rank]
                merged_chunks.append(sub_params_tp_slice)
                offset += sub_dim_size
            tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim)
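
            # Illustrative walk-through with assumed numbers: for the fused example above,
            # sub_param_shape.shape == ((11008, 11008), 4096) and partition_dim == 0, so
            # partition_shape becomes [22016, 4096]. With tp_world_size == 4 the loop narrows
            # rows [0, 11008) and [11008, 22016), keeps this rank's 2752-row chunk of each,
            # and cats them into a [5504, 4096] tp_hp_slice.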
        elif n_sub_params > 1:
            sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
            sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
            tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
        else:
            # this performs the opposite of cat when merging TP slices
            tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]
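
            # Illustrative (assumed numbers): a column-parallel weight saved as a full
            # [4096, 1024] tensor that was cat'ed on dim 0 from 4 TP slices is chunked back
            # into four [1024, 1024] slices here, and this rank keeps slice tp_rank.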

        tp_hp_slice = tp_hp_slice.flatten()

        lp_frag_address = hp_mapping.lp_fragment_address
        tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
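
        # Illustrative (assumed numbers): if this rank's ZeRO partition owns
        # lp_fragment_address.start == 1_048_576 and lp_fragment_address.numel == 524_288,
        # the fragment is elements [1_048_576, 1_572_864) of the flattened TP slice.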
        # print(f"{key} SHAPE: {tp_hp_slice.shape=}")
        # print(f"{key} SHAPE: {dst_tensor.shape=}")
        # print(f"{key} SHAPE: {tp_hp_fragment.shape=}")

        if key == FP32_WEIGHT_KEY:
            dst_tensor = hp_mapping.get_hp_fragment()
            assert dst_tensor.numel() == lp_frag_address.numel, \
                f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
            dst_tensor.data.copy_(tp_hp_fragment.data)
        else:
            assert tp_hp_fragment.numel() == lp_frag_address.numel, \
                f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'
            hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()

    return step


def enable_universal_checkpoint(param_list):
    # Patch each param so universal checkpoint fragments can be loaded per-parameter.
    for param in param_list:
        param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, param)
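
# A hedged usage sketch added for illustration (the names, per-parameter folder layout, and
# surrounding setup are the editor's assumptions, not an API contract): after the ZeRO
# optimizer has built each param's _hp_mapping, loading a universal checkpoint would look
# roughly like
#
#     enable_universal_checkpoint(param_list=model_params)
#     for param, folder in zip(model_params, per_param_folders):
#         # e.g. folder = "<ckpt_dir>/zero/<param_name>" holding fp32.pt, exp_avg.pt, ...
#         step = param.load_hp_checkpoint_state(folder, tp_rank=tp_rank, tp_world_size=tp_world_size)
#
# i.e. each rank extracts only its own TP slice and ZeRO-owned fragment from the universal files.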