test_latest_checkpoint.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import pytest

from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import checkpoint_correctness_verification
from deepspeed.ops.op_builder import FusedAdamBuilder

if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]:
    pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
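

# These tests cover checkpoint loading through the auto-generated 'latest' tag
# file: saving without an explicit tag should leave behind a 'latest' file that
# a later load can resolve, while loading from a directory without one should
# be a harmless no-op.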
class TestLatestCheckpoint(DistributedTest):
    world_size = 1
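
    # Round-trip save/load where the 'latest' tag file exists: the helper saves
    # a checkpoint without an explicit tag and verifies the reloaded model and
    # optimizer state match the originals.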
    def test_existing_latest(self, tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            }
        }
        hidden_dim = 10
        models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
        checkpoint_correctness_verification(config_dict=config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=False,
                                            fp16=False,
                                            empty_tag=True)
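
    # Loading from a directory with no 'latest' tag file should silently do
    # nothing rather than raise.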
    def test_missing_latest(self, tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            }
        }
        hidden_dim = 10
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
        # should be no-op, since latest doesn't exist
        model.load_checkpoint(tmpdir)