# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import os
import copy
import collections
import json
from abc import ABC, abstractmethod

from deepspeed.utils import logger
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine

from .weight_quantizer import WeightQuantization

AUTO_MODULE_KEY = 'auto'


class SDLoaderFactory:

    @staticmethod
    def get_sd_loader_json(json_file, checkpoint_engine):
        if isinstance(json_file, str):
            with open(json_file) as f:
                data = json.load(f)
        else:
            assert isinstance(json_file, dict)
            data = json_file

        sd_type = data['type']
        ckpt_list = data['checkpoints']
        version = data['version']
        ckpt_type = data.get('parallelization', 'pp')
        mp_size = data.get('mp_size', 0)
        if sd_type.lower() in ['bloom', 'ds_model']:
            return data
        return SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine, sd_type, version)

    @staticmethod
    def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None):
        if sd_type == 'Megatron':
            return MegatronSDLoader(ckpt_list, version, checkpoint_engine)
        else:
            assert False, '{} checkpoint type is not supported'.format(sd_type)
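

# Example of the JSON descriptor accepted by SDLoaderFactory.get_sd_loader_json
# (a minimal sketch; the shard file names below are hypothetical):
#
#   {
#       "type": "Megatron",
#       "version": 1.0,
#       "checkpoints": ["mp_rank_00_model_states.pt", "mp_rank_01_model_states.pt"],
#       "parallelization": "pp",
#       "mp_size": 2
#   }
#
# For 'bloom'/'ds_model' types the dict is returned unchanged; otherwise a
# MegatronSDLoader is constructed over the listed checkpoint shards.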


class SDLoaderBase(ABC):

    def __init__(self, ckpt_list, version, checkpoint_engine):
        self.module_key = None
        self.ckpt_list = ckpt_list
        self.version = version
        self.checkpoint_engine = TorchCheckpointEngine() if checkpoint_engine is None else checkpoint_engine
        self.check_ckpt_list()

    def load(self,
             mp_world_size,
             mp_rank,
             module_key=AUTO_MODULE_KEY,
             is_pipe_parallel=False,
             quantize=False,
             quantize_bits=8,
             quantize_groups=64,
             mlp_extra_grouping=True):
        self.module_key = module_key
        num_ckpt = len(self.ckpt_list)
        idx = mp_rank * num_ckpt // mp_world_size
  56. """ We have multiple cases to handle here for both training and inference:
  57. 1. PipeModule loading mp_rank_*.pt files, is_pipe_parallel=True, module_key is not None
  58. a. if no mp_size/pp_size resizing occurs, for both training & inference, loading
  59. the mp_rank related checkpoint directly.
  60. b. if has mp_size/pp_size resizing, only Megatron model inference is supported,
  61. in this case each mp_rank_*.pt have same content, we will load the first checkpoint
  62. file (idx=0), to avoid idx exceeding file list boundary.
  63. 2. PipeModule loading layer_*.pt files, is_pipe_parallel=True, module_key is None
  64. a. if no mp_size resizing occurs, for both training & inference, loading
  65. the mp_rank related checkpoint directly.
  66. b. if has mp_size resizing, only Megatron model inference is supported,
  67. checkpoint file(s) will be merged/split according to mp_rank, mp_world_size and
  68. checkpoint file list.
  69. 3. Non-PipeModule loading mp_rank_*.pt files, is_pipe_parallel=False
  70. Same with case (2).
  71. """
        if is_pipe_parallel and module_key is not None and mp_world_size != num_ckpt:
            mp_world_size = num_ckpt
            idx = 0
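
        # Worked example (illustrative numbers): with num_ckpt=4 saved shards and
        # mp_world_size=2 target ranks, rank 1 gets idx = 1 * 4 // 2 = 2 and, since
        # num_ckpt > mp_world_size, merges shards 2 and 3 below. Conversely, with
        # num_ckpt=2 and mp_world_size=4, rank 3 takes the split path instead.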
        load_path = self.ckpt_list[idx]

        merge_count = 1
        if num_ckpt == mp_world_size:
            assert os.path.exists(load_path)
            #logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}')
            sd = self.checkpoint_engine.load(load_path, map_location=lambda storage, loc: storage)

            if quantize:
                quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)
                sd_module, all_scales = quantizer.sd_quantize_megatron(self.get_module(sd), quantize_bits,
                                                                       quantize_groups)
                self.set_module(sd, sd_module)
            else:
                all_scales = None
        elif num_ckpt > mp_world_size:
            sd, all_scales, merge_count = self.merge_state_dict(mp_world_size, mp_rank, quantize, quantize_bits,
                                                                quantize_groups, mlp_extra_grouping)
        else:
            sd, all_scales = self.split_state_dict(mp_world_size, mp_rank, quantize, quantize_bits, quantize_groups,
                                                   mlp_extra_grouping)
        return load_path, sd, (all_scales, merge_count)

    def get_merge_state_dicts(self, mp_world_size, mp_rank):
        num_ckpt = len(self.ckpt_list)
        assert num_ckpt % mp_world_size == 0, 'Invalid checkpoints and world size for sd merge'

        num_to_merge = num_ckpt // mp_world_size
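        # e.g. num_ckpt=4, mp_world_size=2: each rank merges num_to_merge=2
        # consecutive shards, so rank 0 takes shards [0, 1] and rank 1 takes [2, 3].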
        ckpt_list = [self.ckpt_list[i] for i in range(num_to_merge * mp_rank, num_to_merge * (mp_rank + 1))]
        logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}")

        sd_list = [self.checkpoint_engine.load(ckpt, map_location=lambda storage, loc: storage) for ckpt in ckpt_list]
        return sd_list

    def get_split_state_dict(self, mp_world_size, mp_rank):
        num_ckpt = len(self.ckpt_list)
        assert mp_world_size % num_ckpt == 0, 'Invalid checkpoints and world size for sd split'

        num_to_split = mp_world_size // num_ckpt
        ckpt_index = mp_rank // num_to_split
        ckpt_offset = mp_rank % num_to_split
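        # e.g. num_ckpt=2, mp_world_size=4: num_to_split=2, so ranks 0-1 read
        # shard 0 (offsets 0, 1) and ranks 2-3 read shard 1 (offsets 0, 1).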
        logger.info(f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}")

        sd = self.checkpoint_engine.load(self.ckpt_list[ckpt_index], map_location=lambda storage, loc: storage)
        return sd, num_to_split, ckpt_offset

    def _choose_module_key(self, sd):
        assert not ('module' in sd
                    and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
        assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' nor 'module' keys, not sure how to proceed"
        if 'module' in sd:
            return 'module'
        elif 'model' in sd:
            return 'model'

    def get_module(self, sd):
        if self.module_key is None:
            return sd
        elif self.module_key == AUTO_MODULE_KEY:
            return sd[self._choose_module_key(sd)]
        else:
            return sd[self.module_key]

    def set_module(self, sd, module):
        if self.module_key is None:
            sd = module
        elif self.module_key == AUTO_MODULE_KEY:
            sd[self._choose_module_key(sd)] = module
        else:
            sd[self.module_key] = module
        return sd

    def check_ckpt_list(self):
        #logger.info(f'checkpoint file list: {self.ckpt_list}')
        assert len(self.ckpt_list) > 0

        sd = self.checkpoint_engine.load(self.ckpt_list[0], map_location=lambda storage, loc: storage)

        # check that the checkpoint count matches the saved mp_world_size
        if 'mp_world_size' in sd.keys():
            assert len(self.ckpt_list) == sd[
                'mp_world_size'], f"checkpoint count {len(self.ckpt_list)} is different from saved mp_world_size {sd['mp_world_size']}"

    @abstractmethod
    def merge_state_dict(self, mp_world_size, mp_rank, quantize, quantize_bits, groups, mlp_extra_grouping):
        pass

    @abstractmethod
    def split_state_dict(self, mp_world_size, mp_rank, quantize, quantize_bits, groups, mlp_extra_grouping):
        pass

    @abstractmethod
    def sanity_check(self, ckpt_file_name):
        pass


class MegatronSDLoader(SDLoaderBase):

    def __init__(self, ckpt_list, version, checkpoint_engine):
        super().__init__(ckpt_list, version, checkpoint_engine)
        """
        ## Q/K/V data need special processing
        key: transformer.layers.0.attention.query_key_value.weight, shape: torch.Size([3192, 4256])
        key: transformer.layers.0.attention.query_key_value.bias, shape: torch.Size([3192])

        ## merge or split on axis=0
        key: word_embeddings.weight, shape: torch.Size([12672, 4256])
        key: transformer.layers.0.mlp.dense_h_to_4h.bias, shape: torch.Size([4256])
        key: transformer.layers.0.mlp.dense_h_to_4h.weight, shape: torch.Size([4256, 4256])

        ## merge or split on axis=1
        key: transformer.layers.0.attention.dense.weight, shape: torch.Size([4256, 1064])
        key: transformer.layers.0.mlp.dense_4h_to_h.weight, shape: torch.Size([4256, 4256])

        ## no change required
        key: transformer.layers.0.mlp.dense_4h_to_h.bias, shape: torch.Size([4256])
        key: transformer.final_layernorm.weight, shape: torch.Size([4256])
        key: transformer.final_layernorm.bias, shape: torch.Size([4256])
        key: transformer.layers.0.attention.dense.bias, shape: torch.Size([4256])
        key: transformer.layers.0.post_attention_layernorm.weight, shape: torch.Size([4256])
        key: transformer.layers.0.post_attention_layernorm.bias, shape: torch.Size([4256])
        key: transformer.layers.0.input_layernorm.weight, shape: torch.Size([4256])
        key: transformer.layers.0.input_layernorm.bias, shape: torch.Size([4256])
        key: position_embeddings.weight, shape: torch.Size([1024, 4256])
        """

    def merge_query_key_value(self, param_list, ckpt_ver):
        """
        So far we have found three Q/K/V parameter formats across Megatron checkpoint versions:
        1. version 0: no version information is saved in the checkpoint.
            format: [(3 * np * hn), h]
        2. version 1.0
            format: [(np * hn * 3), h]
        3. version 2.0
            format: [(np * 3 * hn), h]
        h: hidden size
        n: number of attention heads
        p: number of model parallel partitions
        np: n/p
        hn: h/n
        """
        new_qkv = None
        if ckpt_ver == 0:
            # [(3 * np * hn), h]
            assert param_list[0].shape[0] % 3 == 0
            size_qkv = param_list[0].shape[0] // 3
            split_tensors = [torch.split(param, size_qkv, dim=0) for param in param_list]

            tensors = []
            for i in range(3):
                tensor_tuple = [t[i] for t in split_tensors]
                tensors.append(torch.cat(tensor_tuple, axis=0))
            new_qkv = torch.cat(tensors, axis=0)
        elif ckpt_ver == 1.0 or ckpt_ver == 2.0:
            # [(np * hn * 3), h] or [(np * 3 * hn), h]
            new_qkv = torch.cat(param_list, axis=0)
        else:
            assert False, f'checkpoint version: {ckpt_ver} is not supported'

        return new_qkv

    def split_query_key_value(self, param, num_to_split, offset, ckpt_ver):
        """
        So far we have found three Q/K/V parameter formats across Megatron checkpoint versions:
        1. version 0: no version information is saved in the checkpoint.
            format: [(3 * np * hn), h]
        2. version 1.0
            format: [(np * hn * 3), h]
        3. version 2.0
            format: [(np * 3 * hn), h]
        h: hidden size
        n: number of attention heads
        p: number of model parallel partitions
        np: n/p
        hn: h/n
        """
        new_qkv = None
        if ckpt_ver == 0:
            # [(3 * np * hn), h]
            assert param.shape[0] % 3 == 0
            size_qkv = param.shape[0] // 3
            split_tensors = torch.split(param, size_qkv, dim=0)

            assert split_tensors[0].shape[0] % num_to_split == 0
            split_size = split_tensors[0].shape[0] // num_to_split

            tensors = []
            for i in range(3):
                tensors.append(torch.split(split_tensors[i], split_size, dim=0)[offset])
            new_qkv = torch.cat(tensors, axis=0)
        elif ckpt_ver == 1.0 or ckpt_ver == 2.0:
            # [(np * hn * 3), h] or [(np * 3 * hn), h]
            assert param.shape[0] % num_to_split == 0
            size_qkv = param.shape[0] // num_to_split
            split_tensors = torch.split(param, size_qkv, dim=0)
            new_qkv = split_tensors[offset]
        else:
            assert False, f'checkpoint version: {ckpt_ver} is not supported'

        return new_qkv

    def merge_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize=False,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
        self.sanity_check(self.ckpt_list[0])

        sd_list = self.get_merge_state_dicts(mp_world_size, mp_rank)
        ds_sd = copy.deepcopy(sd_list[0])
        new_client_sd = collections.OrderedDict()

        client_sd_list = [self.get_module(sd) for sd in sd_list]
        keys = client_sd_list[0].keys()

        ckpt_ver = self.get_checkpoint_version(ds_sd)
        logger.info(f"checkpoint version: {ckpt_ver}")
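        # Key routing below: row-parallel weights (attention.dense,
        # mlp.dense_4h_to_h) are concatenated on axis=1; column-parallel
        # parameters (mlp.dense_h_to_4h, word_embeddings) on axis=0; QKV uses the
        # version-aware merge (a plain dim-0 concat when quantized); all
        # remaining keys are replicated across shards, so shard 0's copy is kept.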
        if quantize:
            quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)

        for key in keys:
            value_list = [sd[key] for sd in client_sd_list]

            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
                if quantize:
                    value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key, merge_dim=1)
                new_client_sd[key] = torch.cat(value_list, axis=1)
            elif "attention.query_key_value" in key:
                if quantize and "attention.query_key_value.weight" in key:
                    value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key)
                    new_client_sd[key] = torch.cat(value_list, axis=0)
                else:
                    if quantize:
                        new_client_sd[key] = torch.cat(value_list, axis=0)
                    else:
                        new_client_sd[key] = self.merge_query_key_value(value_list, ckpt_ver)
            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key:
                if quantize and "mlp.dense_h_to_4h.weight" in key:
                    value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key)
                new_client_sd[key] = torch.cat(value_list, axis=0)
            else:
                new_client_sd[key] = value_list[0]

        if quantize:
            all_scales = quantizer.merge_scales()
        ds_sd = self.set_module(ds_sd, new_client_sd)

        return ds_sd, (all_scales if quantize else None), len(client_sd_list)

    def split_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize=False,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
        #self.sanity_check(self.ckpt_list[0])

        sd, num_to_split, ckpt_offset = self.get_split_state_dict(mp_world_size, mp_rank)
        ds_sd = copy.deepcopy(sd)
        new_client_sd = collections.OrderedDict()

        client_sd = self.get_module(sd)

        ckpt_ver = self.get_checkpoint_version(ds_sd)
        logger.info(f"checkpoint version: {ckpt_ver}")
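        # Mirror of merge_state_dict's routing: row-parallel weights are sliced
        # on dim=1; column-parallel parameters (including final_linear.weight
        # here) on dim=0; QKV uses the version-aware split; replicated
        # parameters are copied through unchanged.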
        if quantize:
            quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)

        for key in client_sd.keys():
            value = client_sd[key]

            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
                assert value.shape[1] % num_to_split == 0
                split_size = value.shape[1] // num_to_split
                if quantize:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = torch.split(value, split_size, dim=1)[ckpt_offset]
            elif "attention.query_key_value" in key:
                if quantize and "attention.query_key_value.weight" in key:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = self.split_query_key_value(value, num_to_split, ckpt_offset, ckpt_ver)
            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key or "final_linear.weight" in key:
                assert value.shape[0] % num_to_split == 0
                split_size = value.shape[0] // num_to_split
                if quantize and "mlp.dense_h_to_4h.weight" in key:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = torch.split(value, split_size, dim=0)[ckpt_offset]
            else:
                new_client_sd[key] = value

        if quantize:
            all_scales = quantizer.merge_scales_split(num_to_split)

        ds_sd = self.set_module(ds_sd, new_client_sd)

        return ds_sd, (all_scales if quantize else None)

    def sanity_check(self, ckpt_file_name):
        keys_to_check = [
            "attention.dense.weight", "mlp.dense_4h_to_h.weight", "attention.query_key_value",
            "mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.bias"
        ]

        sd = self.checkpoint_engine.load(ckpt_file_name, map_location=lambda storage, loc: storage)

        # partial_key is a sub-string of one key in the sd
        def check_key_exist(partial_key, sd):
            keys = sd.keys()
            found = False
            for k in keys:
                if partial_key in k:
                    found = True
                    break
            return found

        for key in keys_to_check:
            assert check_key_exist(key,
                                   self.get_module(sd)), f'key: {key} is not found in the checkpoint {ckpt_file_name}'

    def get_checkpoint_version(self, state_dict):
        # Use 0 if version info doesn't exist
        return self.version if self.version is not None else state_dict.get('checkpoint_version', 0)
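

# Illustrative usage sketch (not part of the original module's API). The shard
# file names below are hypothetical; the demo only executes when such files
# actually exist on disk, so importing this module stays side-effect free.
if __name__ == '__main__':
    demo_ckpts = ['mp_rank_00_model_states.pt', 'mp_rank_01_model_states.pt']
    if all(os.path.exists(p) for p in demo_ckpts):
        loader = SDLoaderFactory.get_sd_loader(demo_ckpts, checkpoint_engine=None, sd_type='Megatron')
        # Ask for model-parallel rank 0 of 2. Since the shard count equals
        # mp_world_size here, the matching file is loaded directly; with more
        # shards than ranks they would be merged, with fewer they would be split.
        load_path, sd, (all_scales, merge_count) = loader.load(mp_world_size=2, mp_rank=0)
        logger.info(f"loaded {load_path}, merged {merge_count} shard(s)")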