# state_dict_factory.py
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import torch
import os
import copy
import collections
import json
from abc import ABC, abstractmethod

from deepspeed.utils import logger
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine

from .weight_quantizer import WeightQuantization

AUTO_MODULE_KEY = 'auto'


class SDLoaderFactory:
    @staticmethod
    def get_sd_loader_json(json_file, checkpoint_engine):
        if isinstance(json_file, str):
            with open(json_file) as f:
                data = json.load(f)
        else:
            assert isinstance(json_file, dict)
            data = json_file

        sd_type = data['type']
        ckpt_list = data['checkpoints']
        version = data['version']
        ckpt_type = data.get('parallelization', 'pp')
        mp_size = data.get('mp_size', 0)
        if sd_type.lower() in ['bloom', 'ds_model']:
            return data
        return SDLoaderFactory.get_sd_loader(ckpt_list,
                                             checkpoint_engine,
                                             sd_type,
                                             version)

    @staticmethod
    def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None):
        if sd_type == 'Megatron':
            return MegatronSDLoader(ckpt_list, version, checkpoint_engine)
        else:
            assert False, '{} checkpoint type is not supported'.format(sd_type)
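

# Illustrative sketch of the checkpoint-descriptor layout consumed by
# SDLoaderFactory.get_sd_loader_json(). The concrete values (file names, mp_size,
# version) are assumptions for illustration, not taken from a real checkpoint.
_EXAMPLE_CHECKPOINT_JSON = {
    'type': 'Megatron',          # routed to MegatronSDLoader by get_sd_loader()
    'version': 1.0,              # Megatron checkpoint version
    'parallelization': 'pp',     # read into ckpt_type
    'mp_size': 2,                # model-parallel degree the checkpoint was saved with
    'checkpoints': [             # one state-dict file per model-parallel rank
        'mp_rank_00_model_states.pt',
        'mp_rank_01_model_states.pt',
    ],
}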


class SDLoaderBase(ABC):
    def __init__(self, ckpt_list, version, checkpoint_engine):
        self.module_key = None
        self.ckpt_list = ckpt_list
        self.version = version
        self.checkpoint_engine = TorchCheckpointEngine() if checkpoint_engine is None else checkpoint_engine
        self.check_ckpt_list()

    def load(self,
             mp_world_size,
             mp_rank,
             module_key=AUTO_MODULE_KEY,
             is_pipe_parallel=False,
             quantize=False,
             quantize_bits=8,
             quantize_groups=64,
             mlp_extra_grouping=True):
        self.module_key = module_key
        num_ckpt = len(self.ckpt_list)
        idx = mp_rank * num_ckpt // mp_world_size

        """ We have multiple cases to handle here for both training and inference:
            1. PipeModule loading mp_rank_*.pt files, is_pipe_parallel=True, module_key is not None
                a. if no mp_size/pp_size resizing occurs, for both training & inference, load
                   the mp_rank-related checkpoint directly.
                b. if mp_size/pp_size resizing has occurred, only Megatron model inference is supported;
                   in this case each mp_rank_*.pt has the same content, so we load the first checkpoint
                   file (idx=0) to avoid idx exceeding the file list boundary.
            2. PipeModule loading layer_*.pt files, is_pipe_parallel=True, module_key is None
                a. if no mp_size resizing occurs, for both training & inference, load
                   the mp_rank-related checkpoint directly.
                b. if mp_size resizing has occurred, only Megatron model inference is supported;
                   checkpoint file(s) will be merged/split according to mp_rank, mp_world_size and
                   the checkpoint file list.
            3. Non-PipeModule loading mp_rank_*.pt files, is_pipe_parallel=False
                Same as case (2).
        """
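        # Worked example (illustrative): with num_ckpt=4 files and mp_world_size=2, ranks 0 and 1
        # get idx 0 and 2 and each merges two files (num_ckpt > mp_world_size); with num_ckpt=2 and
        # mp_world_size=4, ranks 0..3 get idx 0,0,1,1 and split their file (num_ckpt < mp_world_size).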
        if is_pipe_parallel and module_key is not None and mp_world_size != num_ckpt:
            mp_world_size = num_ckpt
            idx = 0

        load_path = self.ckpt_list[idx]

        merge_count = 1
        if num_ckpt == mp_world_size:
            assert os.path.exists(load_path)
            #logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}')
            sd = self.checkpoint_engine.load(load_path,
                                             map_location=lambda storage, loc: storage)

            if quantize:
                quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
                                               mp_size=mp_world_size)
                sd_module, all_scales = quantizer.sd_quantize_megatron(self.get_module(sd), quantize_bits, quantize_groups)
                self.set_module(sd, sd_module)
            else:
                all_scales = None
        elif num_ckpt > mp_world_size:
            sd, all_scales, merge_count = self.merge_state_dict(mp_world_size, mp_rank, quantize, \
                quantize_bits, quantize_groups, mlp_extra_grouping)
        else:
            sd, all_scales = self.split_state_dict(mp_world_size, mp_rank, quantize, quantize_bits, \
                quantize_groups, mlp_extra_grouping)

        return load_path, sd, (all_scales, merge_count)

    def get_merge_state_dicts(self, mp_world_size, mp_rank):
        num_ckpt = len(self.ckpt_list)
        assert num_ckpt % mp_world_size == 0, 'Invalid checkpoints and world size for sd merge'

        num_to_merge = num_ckpt // mp_world_size
        ckpt_list = [
            self.ckpt_list[i] for i in range(num_to_merge * mp_rank,
                                             num_to_merge * (mp_rank + 1))
        ]

        logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}")
        sd_list = [
            self.checkpoint_engine.load(ckpt, map_location=lambda storage, loc: storage)
            for ckpt in ckpt_list
        ]
        return sd_list

    def get_split_state_dict(self, mp_world_size, mp_rank):
        num_ckpt = len(self.ckpt_list)
        assert mp_world_size % num_ckpt == 0, 'Invalid checkpoints and world size for sd split'

        num_to_split = mp_world_size // num_ckpt
        ckpt_index = mp_rank // num_to_split
        ckpt_offset = mp_rank % num_to_split
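        # Worked example (illustrative): with 2 checkpoint files and mp_world_size=4, num_to_split=2;
        # ranks 0,1 read ckpt_list[0] with offsets 0,1 and ranks 2,3 read ckpt_list[1] with offsets 0,1.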
        logger.info(
            f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}"
        )

        sd = self.checkpoint_engine.load(self.ckpt_list[ckpt_index],
                                         map_location=lambda storage, loc: storage)

        return sd, num_to_split, ckpt_offset

    def _choose_module_key(self, sd):
        assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
        assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' nor 'module' keys, not sure how to proceed"
        if 'module' in sd:
            return 'module'
        elif 'model' in sd:
            return 'model'

    def get_module(self, sd):
        if self.module_key is None:
            return sd
        elif self.module_key == AUTO_MODULE_KEY:
            return sd[self._choose_module_key(sd)]
        else:
            return sd[self.module_key]

    def set_module(self, sd, module):
        if self.module_key is None:
            sd = module
        elif self.module_key == AUTO_MODULE_KEY:
            sd[self._choose_module_key(sd)] = module
        else:
            sd[self.module_key] = module
        return sd

    def check_ckpt_list(self):
        #logger.info(f'checkpoint file list: {self.ckpt_list}')
        assert len(self.ckpt_list) > 0

        sd = self.checkpoint_engine.load(self.ckpt_list[0],
                                         map_location=lambda storage, loc: storage)

        # check that the checkpoint count matches the saved mp_world_size
        if 'mp_world_size' in sd.keys():
            assert len(self.ckpt_list) == sd['mp_world_size'], f"checkpoint count {len(self.ckpt_list)} is different from saved mp_world_size {sd['mp_world_size']}"

    @abstractmethod
    def merge_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize,
                         quantize_bits,
                         groups,
                         mlp_extra_grouping):
        pass

    @abstractmethod
    def split_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize,
                         quantize_bits,
                         groups,
                         mlp_extra_grouping):
        pass

    @abstractmethod
    def sanity_check(self, ckpt_file_name):
        pass


class MegatronSDLoader(SDLoaderBase):
    def __init__(self, ckpt_list, version, checkpoint_engine):
        super().__init__(ckpt_list, version, checkpoint_engine)
        """
        ## Q/K/V data needs special processing
        key: transformer.layers.0.attention.query_key_value.weight, shape: torch.Size([3192, 4256])
        key: transformer.layers.0.attention.query_key_value.bias, shape: torch.Size([3192])

        ## merge or split on axis=0
        key: word_embeddings.weight, shape: torch.Size([12672, 4256])
        key: transformer.layers.0.mlp.dense_h_to_4h.bias, shape: torch.Size([4256])
        key: transformer.layers.0.mlp.dense_h_to_4h.weight, shape: torch.Size([4256, 4256])

        ## merge or split on axis=1
        key: transformer.layers.0.attention.dense.weight, shape: torch.Size([4256, 1064])
        key: transformer.layers.0.mlp.dense_4h_to_h.weight, shape: torch.Size([4256, 4256])

        ## no change required
        key: transformer.layers.0.mlp.dense_4h_to_h.bias, shape: torch.Size([4256])
        key: transformer.final_layernorm.weight, shape: torch.Size([4256])
        key: transformer.final_layernorm.bias, shape: torch.Size([4256])
        key: transformer.layers.0.attention.dense.bias, shape: torch.Size([4256])
        key: transformer.layers.0.post_attention_layernorm.weight, shape: torch.Size([4256])
        key: transformer.layers.0.post_attention_layernorm.bias, shape: torch.Size([4256])
        key: transformer.layers.0.input_layernorm.weight, shape: torch.Size([4256])
        key: transformer.layers.0.input_layernorm.bias, shape: torch.Size([4256])
        key: position_embeddings.weight, shape: torch.Size([1024, 4256])
        """

    def merge_query_key_value(self, param_list, ckpt_ver):
        """
        So far we have found 3 Q/K/V parameter formats across different Megatron checkpoint versions:

        1. version 0: no version information is saved in the checkpoint.
            format: [(3 * np * hn), h]
        2. version 1.0
            format: [(np * hn * 3), h]
        3. version 2.0
            format: [(np * 3 * hn), h]

        h: hidden size
        n: number of attention heads
        p: number of model parallel partitions
        np: n/p
        hn: h/n
        """
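        # Worked example (illustrative, version 0 with two partitions): each param in param_list
        # is laid out [q_i; k_i; v_i] along dim 0, so the loop below regroups them as [q_0; q_1],
        # [k_0; k_1], [v_0; v_1] and then concatenates the three groups into [q; k; v].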
        new_qkv = None
        if ckpt_ver == 0:
            # [(3 * np * hn), h]
            assert param_list[0].shape[0] % 3 == 0
            size_qkv = param_list[0].shape[0] // 3
            split_tensors = [torch.split(param, size_qkv, dim=0) for param in param_list]

            tensors = []
            for i in range(3):
                tensor_tuple = [t[i] for t in split_tensors]
                tensors.append(torch.cat(tensor_tuple, axis=0))
            new_qkv = torch.cat(tensors, axis=0)
        elif ckpt_ver == 1.0 or ckpt_ver == 2.0:
            # [(np * hn * 3), h] or [(np * 3 * hn), h]
            new_qkv = torch.cat(param_list, axis=0)
        else:
            assert False, f'checkpoint version: {ckpt_ver} is not supported'

        return new_qkv

    def split_query_key_value(self, param, num_to_split, offset, ckpt_ver):
        """
        So far we have found 3 Q/K/V parameter formats across different Megatron checkpoint versions:

        1. version 0: no version information is saved in the checkpoint.
            format: [(3 * np * hn), h]
        2. version 1.0
            format: [(np * hn * 3), h]
        3. version 2.0
            format: [(np * 3 * hn), h]

        h: hidden size
        n: number of attention heads
        p: number of model parallel partitions
        np: n/p
        hn: h/n
        """
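        # Worked example (illustrative, version 0 with num_to_split=2): the full param is [q; k; v]
        # along dim 0; q, k and v are each split in half and this rank keeps the `offset` half of
        # each, re-concatenated as [q_off; k_off; v_off].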
        new_qkv = None
        if ckpt_ver == 0:
            # [(3 * np * hn), h]
            assert param.shape[0] % 3 == 0
            size_qkv = param.shape[0] // 3
            split_tensors = torch.split(param, size_qkv, dim=0)

            assert split_tensors[0].shape[0] % num_to_split == 0
            split_size = split_tensors[0].shape[0] // num_to_split

            tensors = []
            for i in range(3):
                tensors.append(torch.split(split_tensors[i], split_size, dim=0)[offset])
            new_qkv = torch.cat(tensors, axis=0)
        elif ckpt_ver == 1.0 or ckpt_ver == 2.0:
            # [(np * hn * 3), h] or [(np * 3 * hn), h]
            assert param.shape[0] % num_to_split == 0
            size_qkv = param.shape[0] // num_to_split
            split_tensors = torch.split(param, size_qkv, dim=0)
            new_qkv = split_tensors[offset]
        else:
            assert False, f'checkpoint version: {ckpt_ver} is not supported'

        return new_qkv

    def merge_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize=False,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
        self.sanity_check(self.ckpt_list[0])

        sd_list = self.get_merge_state_dicts(mp_world_size, mp_rank)
        ds_sd = copy.deepcopy(sd_list[0])
        new_client_sd = collections.OrderedDict()

        client_sd_list = [self.get_module(sd) for sd in sd_list]
        keys = client_sd_list[0].keys()
        ckpt_ver = self.get_checkpoint_version(ds_sd)
        logger.info(f"checkpoint version: {ckpt_ver}")
        if quantize:
            quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
                                           mp_size=mp_world_size)

        for key in keys:
            value_list = [sd[key] for sd in client_sd_list]

            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
                if quantize:
                    value_list = quantizer.Quantize(value_list,
                                                    quantize_bits,
                                                    groups,
                                                    key=key,
                                                    merge_dim=1)
                new_client_sd[key] = torch.cat(value_list, axis=1)
            elif "attention.query_key_value" in key:
                if quantize and "attention.query_key_value.weight" in key:
                    value_list = quantizer.Quantize(value_list,
                                                    quantize_bits,
                                                    groups,
                                                    key=key)
                    new_client_sd[key] = torch.cat(value_list, axis=0)
                else:
                    if quantize:
                        new_client_sd[key] = torch.cat(value_list, axis=0)
                    else:
                        new_client_sd[key] = self.merge_query_key_value(
                            value_list,
                            ckpt_ver)
            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key:
                if quantize and "mlp.dense_h_to_4h.weight" in key:
                    value_list = quantizer.Quantize(value_list,
                                                    quantize_bits,
                                                    groups,
                                                    key=key)
                new_client_sd[key] = torch.cat(value_list, axis=0)
            else:
                new_client_sd[key] = value_list[0]

        if quantize:
            all_scales = quantizer.merge_scales()
        ds_sd = self.set_module(ds_sd, new_client_sd)

        return ds_sd, (all_scales if quantize else None), len(client_sd_list)

    def split_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize=False,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
        #self.sanity_check(self.ckpt_list[0])

        sd, num_to_split, ckpt_offset = self.get_split_state_dict(mp_world_size, mp_rank)
        ds_sd = copy.deepcopy(sd)
        new_client_sd = collections.OrderedDict()

        client_sd = self.get_module(sd)
        ckpt_ver = self.get_checkpoint_version(ds_sd)
        logger.info(f"checkpoint version: {ckpt_ver}")
        if quantize:
            quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
                                           mp_size=mp_world_size)

        for key in client_sd.keys():
            value = client_sd[key]

            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
                assert value.shape[1] % num_to_split == 0
                split_size = value.shape[1] // num_to_split
                if quantize:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = torch.split(value, split_size, dim=1)[ckpt_offset]
            elif "attention.query_key_value" in key:
                if quantize and "attention.query_key_value.weight" in key:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = self.split_query_key_value(
                    value,
                    num_to_split,
                    ckpt_offset,
                    ckpt_ver)
            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key or "final_linear.weight" in key:
                assert value.shape[0] % num_to_split == 0
                split_size = value.shape[0] // num_to_split
                if quantize and "mlp.dense_h_to_4h.weight" in key:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = torch.split(value, split_size, dim=0)[ckpt_offset]
            else:
                new_client_sd[key] = value

        if quantize:
            all_scales = quantizer.merge_scales_split(num_to_split)

        ds_sd = self.set_module(ds_sd, new_client_sd)

        return ds_sd, (all_scales if quantize else None)

    def sanity_check(self, ckpt_file_name):
        keys_to_check = [
            "attention.dense.weight",
            "mlp.dense_4h_to_h.weight",
            "attention.query_key_value",
            "mlp.dense_h_to_4h.weight",
            "mlp.dense_h_to_4h.bias"
        ]

        sd = self.checkpoint_engine.load(ckpt_file_name,
                                         map_location=lambda storage, loc: storage)

        # partial_key is a sub-string of one key in the sd
        def check_key_exist(partial_key, sd):
            keys = sd.keys()
            found = False
            for k in keys:
                if partial_key in k:
                    found = True
                    break
            return found

        for key in keys_to_check:
            assert check_key_exist(key, self.get_module(sd)), f'key: {key} is not found in the checkpoint {ckpt_file_name}'

    def get_checkpoint_version(self, state_dict):
        # Use 0 if version info doesn't exist
        return self.version if self.version is not None else state_dict.get(
            'checkpoint_version',
            0)
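

# Minimal usage sketch. The checkpoint paths and the target model-parallel size below are
# assumptions for illustration; in DeepSpeed this loader is normally driven by the runtime
# rather than invoked directly.
if __name__ == "__main__":
    ckpt_files = ['mp_rank_00_model_states.pt', 'mp_rank_01_model_states.pt']
    loader = SDLoaderFactory.get_sd_loader(ckpt_files,
                                           checkpoint_engine=None,
                                           sd_type='Megatron',
                                           version=1.0)
    # load() decides between a direct load, a merge, and a split by comparing
    # len(ckpt_files) with mp_world_size (here they match, so it loads directly).
    load_path, sd, (scales, merge_count) = loader.load(mp_world_size=2,
                                                       mp_rank=0,
                                                       quantize=False)
    logger.info(f"loaded {load_path}, merged {merge_count} shard(s)")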