# test_backward_compat.py

import os
import sys
import unittest
from pathlib import Path

from packaging import version

import ray
import ray.cloudpickle as pickle
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.dqn import DQN
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.utils.checkpoints import get_checkpoint_info
from ray.rllib.utils.test_utils import framework_iterator
from ray.tune.registry import register_env


class TestBackwardCompatibility(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
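        # Pin gym==0.23.1 in the runtime env so that old checkpoints created
        # before the gymnasium move can still be unpickled on the workers (see
        # the gym version conflict noted in `test_old_checkpoint_formats`).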
        ray.init(runtime_env={"pip": ["gym==0.23.1"]})

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_old_checkpoint_formats(self):
        """Tests whether we remain backward compatible (>= 2.0.0) w.r.t. checkpoints."""
        rllib_dir = Path(__file__).parent.parent.parent
        print(f"rllib dir={rllib_dir} exists={os.path.isdir(rllib_dir)}")

        # TODO: Once checkpoints are python version independent (once we stop
        #  using pickle), add 1.0 here as well.
        # The version list below is empty b/c the v0.1 checkpoints are broken
        # due to the gymnasium move: Old gym envs are not recoverable via pickle
        # (gym==0.23.x is not compatible with gym==0.26.x). Re-add "0.1" to
        # re-enable.
        for v in []:
            v = version.Version(v)
            for fw in framework_iterator():
                path_to_checkpoint = os.path.join(
                    rllib_dir,
                    "tests",
                    "backward_compat",
                    "checkpoints",
                    "v" + str(v),
                    "ppo_frozenlake_" + fw,
                )
                print(
                    f"path_to_checkpoint={path_to_checkpoint} "
                    f"exists={os.path.isdir(path_to_checkpoint)}"
                )
                checkpoint_info = get_checkpoint_info(path_to_checkpoint)
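                # The returned info dict contains (among other things) the two
                # keys used below: "checkpoint_version" (a
                # `packaging.version.Version`) and "state_file" (the path to
                # the pickled algorithm state).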
                # v0.1: Need to create the algo first, then restore its state.
                if checkpoint_info["checkpoint_version"] == version.Version("0.1"):
                    # For checkpoints <= v0.1, we need to magically know the
                    # original config used as well as the algo class.
                    with open(checkpoint_info["state_file"], "rb") as f:
                        state = pickle.load(f)
                    worker_state = pickle.loads(state["worker"])
                    algo = PPO(config=worker_state["policy_config"])
                    # Note: We cannot use restore() here because the testing
                    # checkpoints are created with Algorithm.save() by
                    # checkpoints/create_checkpoints.py. I.e., they are missing
                    # all the Tune checkpoint metadata.
                    algo.load_checkpoint(path_to_checkpoint)
                # > v0.1: Simply use the new `Algorithm.from_checkpoint()`
                # staticmethod.
                else:
                    algo = Algorithm.from_checkpoint(path_to_checkpoint)

                    # Also test restoring a Policy from an algo checkpoint.
                    policies = Policy.from_checkpoint(path_to_checkpoint)
                    self.assertTrue("default_policy" in policies)

                print(algo.train())
                algo.stop()

    def test_old_algorithm_config_dicts(self):
        """Tests whether we can build Algorithm objects with old config dicts."""
        config_dict = {
            "evaluation_config": {
                "lr": 0.1,
            },
            "lr": 0.2,
            # Old-style multi-agent dict.
            "multiagent": {
                "policies": {"pol1", "pol2"},
                "policies_to_train": ["pol1"],
                "policy_mapping_fn": lambda aid, episode, worker, **kwargs: "pol1",
            },
        }
        config = AlgorithmConfig.from_dict(config_dict)
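        # The old top-level "multiagent" sub-dict gets translated into the new
        # API: Its entries surface directly as `config.policies`,
        # `config.policy_mapping_fn`, and `config.policies_to_train`.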
        self.assertFalse(config.in_evaluation)
        self.assertTrue(config.lr == 0.2)
        self.assertTrue(config.policies == {"pol1", "pol2"})
        self.assertTrue(config.policy_mapping_fn(1, 2, 3) == "pol1")
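        # `get_evaluation_config_object()` returns a copy of this config with
        # the `evaluation_config` overrides applied (here: lr 0.2 -> 0.1) and
        # `in_evaluation` set to True.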
        eval_config = config.get_evaluation_config_object()
        self.assertTrue(eval_config.in_evaluation)
        self.assertTrue(eval_config.lr == 0.1)

        register_env(
            "test",
            lambda ctx: MultiAgentCartPole(config={"num_agents": ctx["num_agents"]}),
        )

        config = {
            "env_config": {
                "num_agents": 1,
            },
            "lr": 0.001,
            "evaluation_config": {
                "num_envs_per_worker": 4,
                "explore": False,
            },
            "evaluation_num_workers": 1,
            "multiagent": {
                "policies": {
                    "policy1": PolicySpec(),
                },
                "policy_mapping_fn": lambda aid, episode, worker, **kw: "policy1",
                "policies_to_train": ["policy1"],
            },
        }
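        # Passing the old-style dict directly to the Algorithm constructor
        # should also work: It is converted to an `AlgorithmConfig` internally,
        # as the attribute-style assertions below verify.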
        algo = DQN(config=config, env="test")
        self.assertTrue(algo.config.lr == 0.001)
        self.assertTrue(algo.config.evaluation_num_workers == 1)
        self.assertTrue(list(algo.config.policies.keys()) == ["policy1"])
        self.assertTrue(algo.config.explore is True)
        self.assertTrue(algo.evaluation_config.explore is False)
        print(algo.train())
        algo.stop()


if __name__ == "__main__":
    import pytest

    sys.exit(pytest.main(["-v", __file__]))