zero_to_fp32.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. #!/usr/bin/env python
  2. '''Copyright The Microsoft DeepSpeed Team'''
  3. # This script extracts consolidated fp32 weights from ZeRO stage 2 and stage 3 DeepSpeed checkpoints. It gets
  4. # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
  5. # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
  6. # application.
  7. #
  8. # example: python zero_to_fp32.py . pytorch_model.bin
  9. import argparse
  10. import torch
  11. import glob
  12. import math
  13. import os
  14. import re
  15. from collections import OrderedDict
  16. # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
  17. # DeepSpeed data structures it has to be available in the current python environment.
  18. from deepspeed.utils import logger
  19. from deepspeed.checkpoint.constants import (DS_VERSION,
  20. OPTIMIZER_STATE_DICT,
  21. SINGLE_PARTITION_OF_FP32_GROUPS,
  22. FP32_FLAT_GROUPS,
  23. ZERO_STAGE,
  24. PARTITION_COUNT,
  25. PARAM_SHAPES,
  26. BUFFER_NAMES)
  27. debug = 0
  28. # load to cpu
  29. device = torch.device('cpu')
  30. def atoi(text):
  31. return int(text) if text.isdigit() else text
  32. def natural_keys(text):
  33. '''
  34. alist.sort(key=natural_keys) sorts in human order
  35. http://nedbatchelder.com/blog/200712/human_sorting.html
  36. (See Toothy's implementation in the comments)
  37. '''
  38. return [atoi(c) for c in re.split(r'(\d+)', text)]
  39. def get_model_state_file(checkpoint_dir, zero_stage):
  40. if not os.path.isdir(checkpoint_dir):
  41. raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
  42. # there should be only one file
  43. if zero_stage == 2:
  44. file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
  45. elif zero_stage == 3:
  46. file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
  47. if not os.path.exists(file):
  48. raise FileNotFoundError(f"can't find model states file at '{file}'")
  49. return file
  50. def get_optim_files(checkpoint_dir):
  51. # XXX: need to test that this simple glob rule works for multi-node setup too
  52. optim_files = sorted(glob.glob(os.path.join(checkpoint_dir,
  53. "*_optim_states.pt")),
  54. key=natural_keys)
  55. if len(optim_files) == 0:
  56. raise FileNotFoundError(
  57. f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
  58. return optim_files
  59. def parse_model_state(file):
  60. state_dict = torch.load(file, map_location=device)
  61. if BUFFER_NAMES not in state_dict:
  62. raise ValueError(f"{file} is not a model state checkpoint")
  63. buffer_names = state_dict[BUFFER_NAMES]
  64. if debug:
  65. print("Found buffers:", buffer_names)
  66. # recover just the buffers while restoring them to fp32 if they were saved in fp16
  67. buffers = {
  68. k: v.float()
  69. for k,
  70. v in state_dict["module"].items() if k in buffer_names
  71. }
  72. param_shapes = state_dict[PARAM_SHAPES]
  73. ds_version = state_dict.get(DS_VERSION, None)
  74. return buffers, param_shapes, ds_version
  75. def parse_optim_states(files, ds_checkpoint_dir):
  76. total_files = len(files)
  77. state_dicts = []
  78. for f in files:
  79. state_dicts.append(torch.load(f, map_location=device))
  80. if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
  81. raise ValueError(f"{files[0]} is not a zero checkpoint")
  82. zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
  83. world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
  84. # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
  85. # parameters can be different from data parallelism for non-expert parameters. So we can just
  86. # use the max of the partition_count to get the dp world_size.
  87. if type(world_size) is list:
  88. world_size = max(world_size)
  89. if world_size != total_files:
  90. raise ValueError(
  91. f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
  92. "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
  93. )
  94. # the groups are named differently in each stage
  95. if zero_stage == 2:
  96. fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
  97. elif zero_stage == 3:
  98. fp32_groups_key = FP32_FLAT_GROUPS
  99. else:
  100. raise ValueError(f"unknown zero stage {zero_stage}")
  101. if zero_stage == 2:
  102. fp32_flat_groups = [
  103. state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key]
  104. for i in range(len(state_dicts))
  105. ]
  106. elif zero_stage == 3:
  107. # if there is more than one param group, there will be multiple flattened tensors - one
  108. # flattened tensor per group - for simplicity merge them into a single tensor
  109. #
  110. # XXX: could make the script more memory efficient for when there are multiple groups - it
  111. # will require matching the sub-lists of param_shapes for each param group flattened tensor
  112. fp32_flat_groups = [
  113. torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key],
  114. 0) for i in range(len(state_dicts))
  115. ]
  116. return zero_stage, world_size, fp32_flat_groups
  117. def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
  118. """
  119. Returns fp32 state_dict reconstructed from ds checkpoint
  120. Args:
  121. - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
  122. """
  123. print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
  124. optim_files = get_optim_files(ds_checkpoint_dir)
  125. zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
  126. print(
  127. f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
  128. model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
  129. buffers, param_shapes, ds_version = parse_model_state(model_file)
  130. print(f'Parsing checkpoint created by deepspeed=={ds_version}')
  131. if zero_stage == 2:
  132. return _get_fp32_state_dict_from_zero2_checkpoint(world_size,
  133. param_shapes,
  134. fp32_flat_groups,
  135. buffers)
  136. elif zero_stage == 3:
  137. return _get_fp32_state_dict_from_zero3_checkpoint(world_size,
  138. param_shapes,
  139. fp32_flat_groups,
  140. buffers)
  141. def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
  142. param_shapes,
  143. fp32_flat_groups,
  144. buffers):
  145. # Reconstruction protocol:
  146. #
  147. # XXX: document this
  148. if debug:
  149. for i in range(world_size):
  150. for j in range(len(fp32_flat_groups[0])):
  151. print(
  152. f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
  153. # XXX: memory usage doubles here (zero2)
  154. num_param_groups = len(fp32_flat_groups[0])
  155. merged_single_partition_of_fp32_groups = []
  156. for i in range(num_param_groups):
  157. merged_partitions = [sd[i] for sd in fp32_flat_groups]
  158. full_single_fp32_vector = torch.cat(merged_partitions, 0)
  159. merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
  160. avail_numel = sum([
  161. full_single_fp32_vector.numel()
  162. for full_single_fp32_vector in merged_single_partition_of_fp32_groups
  163. ])
  164. if debug:
  165. wanted_params = sum([len(shapes) for shapes in param_shapes])
  166. wanted_numel = sum(
  167. [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
  168. # not asserting if there is a mismatch due to possible padding
  169. print(f"Have {avail_numel} numels to process.")
  170. print(f"Need {wanted_numel} numels in {wanted_params} params.")
  171. state_dict = OrderedDict()
  172. # buffers
  173. state_dict.update(buffers)
  174. if debug:
  175. print(f"added {len(buffers)} buffers")
  176. # params
  177. # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
  178. # out-of-core computing solution
  179. total_numel = 0
  180. total_params = 0
  181. for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
  182. offset = 0
  183. avail_numel = full_single_fp32_vector.numel()
  184. for name, shape in shapes.items():
  185. unpartitioned_numel = shape.numel()
  186. total_numel += unpartitioned_numel
  187. total_params += 1
  188. if debug:
  189. print(
  190. f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
  191. )
  192. state_dict[name] = full_single_fp32_vector.narrow(
  193. 0,
  194. offset,
  195. unpartitioned_numel).view(shape)
  196. offset += unpartitioned_numel
  197. # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
  198. # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
  199. # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
  200. # live optimizer object, so we are checking that the numbers are within the right range
  201. align_to = 2 * world_size
  202. def zero2_align(x):
  203. return align_to * math.ceil(x / align_to)
  204. if debug:
  205. print(f"original offset={offset}, avail_numel={avail_numel}")
  206. offset = zero2_align(offset)
  207. avail_numel = zero2_align(avail_numel)
  208. if debug:
  209. print(f"aligned offset={offset}, avail_numel={avail_numel}")
  210. # Sanity check
  211. if offset != avail_numel:
  212. raise ValueError(
  213. f"consumed {offset} numels out of {avail_numel} - something is wrong")
  214. print(
  215. f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
  216. )
  217. return state_dict
  218. def zero3_partitioned_param_info(unpartitioned_numel, world_size):
  219. remainder = unpartitioned_numel % world_size
  220. padding_numel = (world_size - remainder) if remainder else 0
  221. partitioned_numel = math.ceil(unpartitioned_numel / world_size)
  222. return partitioned_numel, padding_numel
  223. def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
  224. param_shapes,
  225. fp32_flat_groups,
  226. buffers):
  227. # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
  228. # param, re-consolidating each param, while dealing with padding if any
  229. avail_numel = fp32_flat_groups[0].numel() * world_size
  230. # merge list of dicts, preserving order
  231. param_shapes = {k: v for d in param_shapes for k, v in d.items()}
  232. if debug:
  233. for i in range(world_size):
  234. print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
  235. wanted_params = len(param_shapes)
  236. wanted_numel = sum(shape.numel() for shape in param_shapes.values())
  237. # not asserting if there is a mismatch due to possible padding
  238. print(f"Have {avail_numel} numels to process.")
  239. print(f"Need {wanted_numel} numels in {wanted_params} params.")
  240. state_dict = OrderedDict()
  241. # buffers
  242. state_dict.update(buffers)
  243. if debug:
  244. print(f"added {len(buffers)} buffers")
  245. # params
  246. # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
  247. # out-of-core computing solution
  248. offset = 0
  249. total_numel = 0
  250. total_params = 0
  251. for name, shape in param_shapes.items():
  252. unpartitioned_numel = shape.numel()
  253. total_numel += unpartitioned_numel
  254. total_params += 1
  255. partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
  256. if debug:
  257. print(
  258. f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
  259. )
  260. # XXX: memory usage doubles here
  261. state_dict[name] = torch.cat(
  262. tuple(fp32_flat_groups[i].narrow(0,
  263. offset,
  264. partitioned_numel)
  265. for i in range(world_size)),
  266. 0).narrow(0,
  267. 0,
  268. unpartitioned_numel).view(shape)
  269. offset += partitioned_numel
  270. offset *= world_size
  271. # Sanity check
  272. if offset != avail_numel:
  273. raise ValueError(
  274. f"consumed {offset} numels out of {avail_numel} - something is wrong")
  275. print(
  276. f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
  277. )
  278. return state_dict
  279. def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
  280. """
  281. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
  282. ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
  283. via a model hub.
  284. Args:
  285. - ``checkpoint_dir``: path to the desired checkpoint folder
  286. - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
  287. Returns:
  288. - pytorch ``state_dict``
  289. Note: this approach may not work if your application doesn't have sufficient free CPU memory and
  290. you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
  291. the checkpoint.
  292. A typical usage might be ::
  293. from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
  294. # do the training and checkpoint saving
  295. state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
  296. model = model.cpu() # move to cpu
  297. model.load_state_dict(state_dict)
  298. # submit to model hub or save the model to share with others
  299. In this example the ``model`` will no longer be usable in the deepspeed context of the same
  300. application. i.e. you will need to re-initialize the deepspeed engine, since
  301. ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
  302. If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
  303. """
  304. if tag is None:
  305. latest_path = os.path.join(checkpoint_dir, 'latest')
  306. if os.path.isfile(latest_path):
  307. with open(latest_path, 'r') as fd:
  308. tag = fd.read().strip()
  309. else:
  310. raise ValueError(f"Unable to find 'latest' file at {latest_path}")
  311. ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
  312. if not os.path.isdir(ds_checkpoint_dir):
  313. raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
  314. return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
  315. def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
  316. """
  317. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
  318. loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
  319. Args:
  320. - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
  321. - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
  322. - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
  323. """
  324. state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
  325. print(f"Saving fp32 state dict to {output_file}")
  326. torch.save(state_dict, output_file)
  327. def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
  328. """
  329. 1. Put the provided model to cpu
  330. 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
  331. 3. Load it into the provided model
  332. Args:
  333. - ``model``: the model object to update
  334. - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
  335. - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
  336. Returns:
  337. - ``model`: modified model
  338. Make sure you have plenty of CPU memory available before you call this function. If you don't
  339. have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
  340. conveniently placed for you in the checkpoint folder.
  341. A typical usage might be ::
  342. from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
  343. model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
  344. # submit to model hub or save the model to share with others
  345. Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
  346. of the same application. i.e. you will need to re-initialize the deepspeed engine, since
  347. ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
  348. """
  349. logger.info(f"Extracting fp32 weights")
  350. state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
  351. logger.info(f"Overwriting model with fp32 weights")
  352. model = model.cpu()
  353. model.load_state_dict(state_dict, strict=False)
  354. return model
  355. if __name__ == "__main__":
  356. parser = argparse.ArgumentParser()
  357. parser.add_argument(
  358. "checkpoint_dir",
  359. type=str,
  360. help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
  361. parser.add_argument(
  362. "output_file",
  363. type=str,
  364. help=
  365. "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
  366. )
  367. parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
  368. args = parser.parse_args()
  369. debug = args.debug
  370. convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)