utils.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import torch
  6. from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX)
  7. def get_model_ckpt_name_for_rank(base_folder, mp_rank_str):
  8. ckpt_name = os.path.join(
  9. base_folder,
  10. MODEL_FILE_PREFIX + mp_rank_str + MODEL_FILE_SUFFIX,
  11. )
  12. return ckpt_name
  13. def get_zero_ckpt_name_for_rank(base_folder, dp_rank, mp_rank):
  14. zero_prefix = f'{ZERO_FILE_PREFIX}{dp_rank}'
  15. mp_rank_string = f'_{MODEL_FILE_PREFIX}{mp_rank:02d}'
  16. zero_ckpt_name = os.path.join(
  17. base_folder,
  18. zero_prefix + mp_rank_string + OPTIM_FILE_SUFFIX,
  19. )
  20. return zero_ckpt_name
  21. def get_layer_ckpt_name_for_rank(base_folder, layer_id, tp_rank):
  22. ckpt_file = f'{layer_id}-model_{tp_rank:02d}{MODEL_FILE_SUFFIX}'
  23. ckpt_path = os.path.join(base_folder, ckpt_file)
  24. return ckpt_path
  25. # We pass cloned tensors to torch.save() to avoid checkpoint bloat that occurs when torch.save()
  26. # saves the underlying storage rather than the slice of the storage corresponding to individual tensors.
  27. # This is a problem in DeepSpeed because we often allocate tensors using slices of large flattened buffers.
  28. # Tensor cloning helps to avoid this problem because the storage of cloned tensors are closer to the true size.
  29. # It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat.
  30. # See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing
  31. def clone_tensors_for_torch_save(item, device=torch.device('cpu')):
  32. """
  33. Returns a copy of ``item`` with all enclosed tensors replaced by clones on a specified device.
  34. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
  35. Parameters:
  36. - ``item``: tensor to clone or (possibly nested) container of tensors to clone.
  37. - ``device``: target device (defaults to 'cpu')
  38. Returns:
  39. - copy of ``item`` with cloned tensors on target device
  40. """
  41. if torch.is_tensor(item):
  42. return item.detach().clone().to(device)
  43. elif isinstance(item, list):
  44. return [clone_tensors_for_torch_save(v, device) for v in item]
  45. elif isinstance(item, tuple):
  46. return tuple([clone_tensors_for_torch_save(v, device) for v in item])
  47. elif isinstance(item, dict):
  48. return type(item)({k: clone_tensors_for_torch_save(v, device) for k, v in item.items()})
  49. else:
  50. return item