# test_learner_group_checkpointing.py
  1. import gymnasium as gym
  2. import itertools
  3. import numpy as np
  4. import tempfile
  5. import unittest
  6. import ray
  7. from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
  8. from ray.rllib.core.testing.utils import get_learner_group
  9. from ray.rllib.policy.sample_batch import SampleBatch
  10. from ray.rllib.utils.test_utils import check
# A hand-built, deterministic 3-timestep sample batch used as the training
# input for every update call in this file. Observations are 4-dim float
# vectors and there are 2 discrete actions (matching the "CartPole-v1" env
# constructed in the test below — TODO confirm the spaces line up with the
# learner's module config). Using the same fixed batch for every update makes
# runs with and without a checkpoint break directly comparable.
FAKE_BATCH = {
    SampleBatch.OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.NEXT_OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    # Episode ends via termination at the third step; never truncated.
    SampleBatch.TERMINATEDS: np.array([False, False, True]),
    SampleBatch.TRUNCATEDS: np.array([False, False, False]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    # Two logits per step, consistent with a 2-action discrete distribution.
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
    ),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    # All three steps belong to one episode produced by one agent.
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}
# Scaling configurations the checkpointing test is run under, keyed by a
# human-readable mode name (the key is only used for test parametrization
# and log output). GPU modes presumably require a GPU-equipped CI machine —
# verify against the test-suite's resource tags.
REMOTE_SCALING_CONFIGS = {
    "remote-cpu": LearnerGroupScalingConfig(num_workers=1),
    "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=1),
    "multi-gpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_gpus_per_worker=1),
    "multi-cpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_cpus_per_worker=2),
    # "multi-gpu-ddp-pipeline": LearnerGroupScalingConfig(
    #     num_workers=2, num_gpus_per_worker=2
    # ),
}
  43. class TestLearnerGroupCheckpointing(unittest.TestCase):
  44. def setUp(self) -> None:
  45. ray.init()
  46. def tearDown(self) -> None:
  47. ray.shutdown()
  48. def test_save_load_state(self):
  49. fws = ["tf2", "torch"]
  50. scaling_modes = REMOTE_SCALING_CONFIGS.keys()
  51. test_iterator = itertools.product(fws, scaling_modes)
  52. batch = SampleBatch(FAKE_BATCH)
  53. for fw, scaling_mode in test_iterator:
  54. print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
  55. env = gym.make("CartPole-v1")
  56. scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
  57. initial_learner_group = get_learner_group(
  58. fw, env, scaling_config, eager_tracing=True
  59. )
  60. # checkpoint the initial learner state for later comparison
  61. initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
  62. initial_learner_group.save_state(initial_learner_checkpoint_dir)
  63. initial_learner_group_weights = initial_learner_group.get_weights()
  64. # do a single update
  65. initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)
  66. # checkpoint the learner state after 1 update for later comparison
  67. learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
  68. initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)
  69. # remove that learner, construct a new one, and load the state of the old
  70. # learner into the new one
  71. initial_learner_group.shutdown()
  72. del initial_learner_group
  73. new_learner_group = get_learner_group(
  74. fw, env, scaling_config, eager_tracing=True
  75. )
  76. new_learner_group.load_state(learner_after_1_update_checkpoint_dir)
  77. # do another update
  78. results_with_break = new_learner_group.update(
  79. batch.as_multi_agent(), reduce_fn=None
  80. )
  81. weights_after_1_update_with_break = new_learner_group.get_weights()
  82. new_learner_group.shutdown()
  83. del new_learner_group
  84. # construct a new learner group and load the initial state of the learner
  85. learner_group = get_learner_group(
  86. fw, env, scaling_config, eager_tracing=True
  87. )
  88. learner_group.load_state(initial_learner_checkpoint_dir)
  89. check(learner_group.get_weights(), initial_learner_group_weights)
  90. learner_group.update(batch.as_multi_agent(), reduce_fn=None)
  91. results_without_break = learner_group.update(
  92. batch.as_multi_agent(), reduce_fn=None
  93. )
  94. weights_after_1_update_without_break = learner_group.get_weights()
  95. learner_group.shutdown()
  96. del learner_group
  97. # compare the results of the two updates
  98. check(results_with_break, results_without_break)
  99. check(
  100. weights_after_1_update_with_break, weights_after_1_update_without_break
  101. )
  102. if __name__ == "__main__":
  103. import pytest
  104. import sys
  105. sys.exit(pytest.main(["-v", __file__]))