universal_checkpoint.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import types

from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
    hp_mapping = self._hp_mapping
    optim_state_keys = hp_mapping.get_optim_state_keys()
    hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
    # print(f'{hp_keys=}')
    checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
    for file in checkpoint_files.values():
        assert os.path.isfile(file), f'{file} is not a valid file'
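
    # Expected on-disk layout (an assumption inferred from the keys above): the universal
    # checkpoint keeps one .pt file per key inside this parameter's folder, e.g.
    #   <folder>/<FP32_WEIGHT_KEY value>.pt              - fp32 master weight
    #   <folder>/exp_avg.pt, <folder>/exp_avg_sq.pt, ... - one file per optimizer state key
    # (optimizer state names come from get_optim_state_keys(); exp_avg/exp_avg_sq shown for
    # an Adam-style optimizer)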

    for key in hp_keys:
        ckpt_file = checkpoint_files[key]
        ckpt_dict = torch.load(ckpt_file)
        full_hp_param = ckpt_dict[PARAM]

        # need to deal with slices that were averaged.
        # the opposite of averaging here becomes an exact copy of the first slice
        # I thought of 2 ways:
        # implementation a. find a way for a client to pass a dict with patterns
        # if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS):
        #     tp_rank = 0
        #     tp_world_size = 1
        # the other approach is to assume that the saved data is correct and if full_hp_param.shape ==
        # self.shape that means we automatically copy?
        # implementation b.
        # this version requires no additional data passed from the client
        # if the shapes already match it must be slices that were averaged - so we just hack around those
        if full_hp_param.shape == self.shape:
            tp_rank = 0
            tp_world_size = 1

        # special case for word_embeddings weights which get padded differently depending on TP degree.
        # the converter to universal currently strips the original padding completely so the saved
        # weight is padding-free and we just need to add new padding depending on the target TP degree
        is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
        if is_vocab_tensor:
            # In the absence of data passed from the user wrt new padded vocab specific to tp degree
            # we can again derive that data by reverse engineering the target shapes like so:
            padded_target_vocab_size = self.shape[0] * tp_world_size
            assert padded_target_vocab_size >= full_hp_param.shape[0], \
                f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
            if padded_target_vocab_size > full_hp_param.shape[0]:
                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
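                # torch.nn.functional.pad pads dimensions starting from the last one, so the
                # spec (0, 0, 0, padding_size) above leaves the last (hidden) dim untouched and
                # appends padding_size zero rows to the vocab dim, e.g. an illustrative
                # [50254, 1024] universal weight becomes [50304, 1024] for the target TP degree.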

        full_param_numel = full_hp_param.numel()
        tp_slice_numel = self.numel()
        # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder:
        #     print_rank_0(f'{full_hp_param[:10]=}', force=True)
        assert full_param_numel == tp_world_size * tp_slice_numel, \
            f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'

        dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)

        # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
        # print(f"{dst_tensor.shape=} {dst_tensor.numel()=} {folder=}")

        # since when we do many-to-1 on tp we cat sometimes on dim=0 and other times on dim=1,
        # we have to do exactly the same in reverse
        # special case is when a single parameter is effectively a container for multiple sub parameters
        # (more details at PARAM_N_SUB_PARAMS definition)
        chunk_dim = ckpt_dict.get(CAT_DIM, 0)
        n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
        if n_sub_params > 1:
            sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
            sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
            tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
        else:
            # this performs the opposite of cat when merging TP slices
            tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]

        tp_hp_slice = tp_hp_slice.flatten()
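
        # Illustrative example (sizes made up): a Megatron-style fused QKV weight saved as one
        # [3 * hidden, hidden] tensor with n_sub_params=3 and chunk_dim=0 is first split back
        # into its Q, K and V sub-tensors, each sub-tensor is chunked across tp_world_size
        # ranks along dim 0, and this rank's pieces are re-concatenated - recreating locally
        # what the original TP sharding would have held.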

        lp_frag_address = hp_mapping.lp_fragment_address
        tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
        assert dst_tensor.numel() == lp_frag_address.numel, \
            f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
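        # narrow(0, start, numel) selects the contiguous span [start, start + numel) of the
        # flattened TP slice, i.e. only the elements that this rank's partitioned low-precision
        # fragment (as recorded in the hp/lp mapping) corresponds to.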

        # print(f"{key} SHAPE: {tp_hp_slice.shape=}")
        # print(f"{key} SHAPE: {dst_tensor.shape=}")
        # print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
        dst_tensor.data.copy_(tp_hp_fragment.data)


def enable_universal_checkpoint(param_list):
    for param in param_list:
        param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, param)
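

# Usage sketch (illustrative, not part of the module): enable_universal_checkpoint() binds
# load_hp_checkpoint_state() as a method on each parameter so that checkpoint-loading code can
# later call it per parameter. The snippet assumes the optimizer (e.g. DeepSpeed's bf16/ZeRO
# optimizer) has already attached a `_hp_mapping` to every trainable parameter; `ckpt_dir`,
# `tp_rank` and `tp_world_size` are placeholders.
#
#     params = [p for p in model.parameters() if hasattr(p, '_hp_mapping')]
#     enable_universal_checkpoint(params)
#     for p in params:
#         param_folder = os.path.join(ckpt_dir, ...)  # one folder per named parameter
#         p.load_hp_checkpoint_state(param_folder, tp_rank, tp_world_size)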