test_get_optim_files.py 684 B

12345678910111213141516171819202122
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import pytest
  6. from deepspeed.utils.zero_to_fp32 import get_optim_files
  7. @pytest.mark.parametrize('num_checkpoints', [1, 2, 12, 24])
  8. def test_get_optim_files(tmpdir, num_checkpoints):
  9. saved_files = []
  10. for i in range(num_checkpoints):
  11. file_name = "zero_" + str(i) + "_optim_states.pt"
  12. path_name = os.path.join(tmpdir, file_name)
  13. saved_files.append(path_name)
  14. with open(path_name, "w") as f:
  15. f.write(file_name)
  16. loaded_files = get_optim_files(tmpdir)
  17. for lf, sf in zip(loaded_files, saved_files):
  18. assert lf == sf