Model Checkpointing
===================

DeepSpeed provides routines for checkpointing model state during training.

Loading Training Checkpoints
----------------------------

.. autofunction:: deepspeed.DeepSpeedEngine.load_checkpoint

Saving Training Checkpoints
---------------------------

.. autofunction:: deepspeed.DeepSpeedEngine.save_checkpoint
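
As a minimal sketch of how the two methods fit together, the helpers below save training progress with ``save_checkpoint`` and resume it with ``load_checkpoint``. They assume ``engine`` is the engine returned by ``deepspeed.initialize``. The helper names, the ``step`` key in ``client_state``, and the ``step{N}`` tag format are hypothetical choices for illustration, not part of the DeepSpeed API. Every rank should call both helpers, because ``save_checkpoint`` must be called by all processes, not only rank 0.

.. code-block:: python

    def save_training_state(engine, ckpt_dir, step):
        """Save engine state plus the current step under the (hypothetical) tag ``step{step}``."""
        # All ranks must make this call. client_state is stored next to the
        # engine's own model/optimizer state and handed back on load.
        engine.save_checkpoint(ckpt_dir, tag=f"step{step}", client_state={"step": step})


    def resume_training_state(engine, ckpt_dir):
        """Return the step to resume from, or 0 if no checkpoint could be loaded."""
        # With tag=None, load_checkpoint uses the tag recorded in the "latest"
        # file that save_checkpoint writes by default (save_latest=True).
        load_path, client_state = engine.load_checkpoint(ckpt_dir)
        if load_path is None:
            # Loading failed (e.g. no checkpoint in ckpt_dir yet): start fresh.
            return 0
        return client_state.get("step", 0)

Keeping the step in ``client_state`` means the training loop's own bookkeeping travels with the engine checkpoint, so no separate side file is needed.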

ZeRO Checkpoint fp32 Weights Recovery
-------------------------------------

DeepSpeed provides routines for extracting fp32 weights from the saved ZeRO checkpoint's optimizer states.

.. autofunction:: deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint

.. autofunction:: deepspeed.utils.zero_to_fp32.load_state_dict_from_zero_checkpoint

.. autofunction:: deepspeed.utils.zero_to_fp32.convert_zero_checkpoint_to_fp32_state_dict

Avoiding ZeRO Checkpoint Bloat
------------------------------

ZeRO stage 1 and 2 checkpoints created using ``torch.save()`` can sometimes be larger than expected. This bloat
is caused by the interaction of ZeRO's tensor flattening and PyTorch's tensor `storage management <https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing>`_:
ZeRO flattens parameters into large contiguous buffers, so the model's parameter tensors become views into those buffers, and
``torch.save()`` preserves storage sharing, so saving a view writes its entire underlying storage rather than only the view's elements.
You can avoid this problem by using DeepSpeed's ``clone_tensors_for_torch_save`` utility, as illustrated below.

.. autofunction:: deepspeed.checkpoint.utils.clone_tensors_for_torch_save
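
To see the bloat itself, the sketch below uses plain PyTorch (no DeepSpeed) to save a small view of a large flat tensor, which stands in for a ZeRO flattened buffer. Because ``torch.save()`` preserves storage sharing, the whole underlying storage is written for the view; cloning the view first, which is the idea behind ``clone_tensors_for_torch_save``, writes only its own elements. The buffer size and the ``serialized_size`` helper are illustrative assumptions.

.. code-block:: python

    import io

    import torch


    def serialized_size(obj):
        """Return the number of bytes torch.save() writes for obj."""
        buf = io.BytesIO()
        torch.save(obj, buf)
        return len(buf.getvalue())


    # A large flat buffer, standing in for ZeRO's flattened parameter storage.
    flat = torch.zeros(1_000_000, dtype=torch.float32)
    # A 10-element "parameter" that is only a view into that buffer.
    param = flat[:10]

    bloated = serialized_size({"param": param})       # writes all 1M floats
    lean = serialized_size({"param": param.clone()})  # writes only 10 floats
    print(f"view: {bloated} bytes, clone: {lean} bytes")

The clone owns a fresh storage sized to its own elements, so the serialized file no longer carries the rest of the flat buffer.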

The following code snippet illustrates this functionality for creating a HuggingFace model checkpoint:

.. code-block:: python

    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM

    ds_config = {
        ...
    }
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-13b", torch_dtype=torch.float16)
    ds_engine, _, _, _ = deepspeed.initialize(model=model, config_params=ds_config)
    # Clone each tensor into its own compact storage so that only its elements are saved
    lean_state_dict = deepspeed.checkpoint.utils.clone_tensors_for_torch_save(ds_engine.module.state_dict())
    ds_engine.module.save_pretrained("lean_after", state_dict=lean_state_dict)