  1. """
  2. [1] Mastering Diverse Domains through World Models - 2023
  3. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  4. https://arxiv.org/pdf/2301.04104v1.pdf
  5. [2] Mastering Atari with Discrete World Models - 2021
  6. D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
  7. https://arxiv.org/pdf/2010.02193.pdf
  8. [3]
  9. D. Hafner's (author) original code repo (for JAX):
  10. https://github.com/danijar/dreamerv3
  11. """
import unittest

import gymnasium as gym
import numpy as np

import ray
from ray import tune
from ray.rllib.algorithms.dreamerv3 import dreamerv3
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.numpy import one_hot
from ray.rllib.utils.test_utils import framework_iterator


class TestDreamerV3(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()
    def test_dreamerv3_compilation(self):
        """Test whether DreamerV3 can be built with all frameworks."""
        # Build a DreamerV3Config object.
        config = (
            dreamerv3.DreamerV3Config()
            .training(
                # Keep things simple. Especially the long dream rollouts seem
                # to take an enormous amount of time (initially).
                batch_size_B=4,
                horizon_H=5,
                batch_length_T=16,
                model_size="nano",  # Use a tiny model for testing.
                symlog_obs=True,
                use_float16=False,
            )
            .resources(
                num_learner_workers=2,  # Try with 2 Learners.
                num_cpus_per_learner_worker=1,
                num_gpus_per_learner_worker=0,
            )
        )
        # TODO (sven): Add a `get_model_config` utility to AlgorithmConfig
        #  that - for now - merges the user-provided model dict (which contains
        #  settings that only affect the model, e.g. `model_size`) with the
        #  AlgorithmConfig-wide settings that are relevant for the model
        #  (e.g. `batch_size_B`).
        # config.get_model_config()
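        # A rough sketch of what such a utility could look like (hypothetical;
        # the name `get_model_config` and the merge order come from the TODO
        # above, not from an existing RLlib API):
        #
        #     def get_model_config(config):
        #         # Start from the user-provided, model-only settings ...
        #         model_config = dict(config.model)
        #         # ... then overlay the algo-wide settings the model needs.
        #         model_config["batch_size_B"] = config.batch_size_B
        #         model_config["batch_length_T"] = config.batch_length_T
        #         return model_config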

        num_iterations = 2

        for _ in framework_iterator(config, frameworks="tf2"):
            for env in [
                "FrozenLake-v1",
                "CartPole-v1",
                "ALE/MsPacman-v5",
                "Pendulum-v1",
            ]:
                print("Env={}".format(env))
                # Add one-hot observations for FrozenLake env.
                if env == "FrozenLake-v1":

                    def env_creator(ctx):
                        import gymnasium as gym
                        from ray.rllib.algorithms.dreamerv3.utils.env_runner import (
                            OneHot,
                        )

                        return OneHot(gym.make("FrozenLake-v1"))

                    tune.register_env("frozen-lake-one-hot", env_creator)
                    env = "frozen-lake-one-hot"

                config.environment(env)
                algo = config.build()
                obs_space = algo.workers.local_worker().env.single_observation_space
                act_space = algo.workers.local_worker().env.single_action_space
                rl_module = algo.workers.local_worker().module

                for i in range(num_iterations):
                    results = algo.train()
                    print(results)

                # Test dream trajectory w/ recreated observations: Burn in 5
                # observed timesteps, then dream forward for H=45 steps. The
                # returned sequences span t0 through H, i.e. 45 + 1 = 46
                # timesteps along the time axis.
                sample = algo.replay_buffer.sample()
                dream = rl_module.dreamer_model.dream_trajectory_with_burn_in(
                    start_states=rl_module.dreamer_model.get_initial_state(),
                    timesteps_burn_in=5,
                    timesteps_H=45,
                    observations=sample["obs"][:1],  # B=1
                    actions=(
                        # One-hot discrete actions; continuous actions pass
                        # through unchanged.
                        one_hot(sample["actions"], depth=act_space.n)
                        if isinstance(act_space, gym.spaces.Discrete)
                        else sample["actions"]
                    )[:1],  # B=1
                )
                self.assertTrue(
                    dream["actions_dreamed_t0_to_H_BxT"].shape
                    == (46, 1)
                    + (
                        (act_space.n,)
                        if isinstance(act_space, gym.spaces.Discrete)
                        else tuple(act_space.shape)
                    )
                )
                self.assertTrue(
                    dream["continues_dreamed_t0_to_H_BxT"].shape == (46, 1)
                )
                self.assertTrue(
                    dream["observations_dreamed_t0_to_H_BxT"].shape
                    == [46, 1] + list(obs_space.shape)
                )

                algo.stop()

    def test_dreamerv3_dreamer_model_sizes(self):
        """Tests whether the different model sizes match the ones reported in [1]."""
        # For Atari, these are the exact numbers from the repo ([3]).
        # However, for CartPole and sizes "S" and "M", the author's original code
        # will not match the world-model parameter counts: The author uses
        # encoder/decoder nets with 5x1024 nodes (which corresponds to "XL")
        # regardless of the `model_size` setting (as long as it is >= "S").
        expected_num_params_world_model = {
            "XS_cartpole": 2435076,
            "S_cartpole": 7493380,
            "M_cartpole": 16206084,
            "L_cartpole": 37802244,
            "XL_cartpole": 108353796,
            "XS_atari": 7538979,
            "S_atari": 15687811,
            "M_atari": 32461635,
            "L_atari": 68278275,
            "XL_atari": 181558659,
        }

        # All values confirmed against [3] (100% match).
        expected_num_params_actor = {
            # Layer shapes for XS: hidden=[1280, 256]; hidden_norm=[256], [256];
            # pi (2 actions)=[256, 2], [2].
            "XS_cartpole": 328706,
            "S_cartpole": 1051650,
            "M_cartpole": 2135042,
            "L_cartpole": 4136450,
            "XL_cartpole": 9449474,
            "XS_atari": 329734,
            "S_atari": 1053702,
            "M_atari": 2137606,
            "L_atari": 4139526,
            "XL_atari": 9453574,
        }
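
        # Sanity-check the XS CartPole count from the layer shapes above
        # (assuming the hidden dense layer is bias-free, since it is followed
        # by a LayerNorm with scale and offset):
        # 1280*256 (dense) + 256+256 (norm) + 256*2+2 (pi head) = 328706.
        # The small CartPole/Atari gap is just the pi head: 4 extra actions
        # x (256 weights + 1 bias) = 1028 params.
        assert expected_num_params_actor["XS_cartpole"] == (
            1280 * 256 + 256 + 256 + 256 * 2 + 2
        )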

        # All values confirmed against [3] (100% match).
        expected_num_params_critic = {
            # Layer shapes for XS: hidden=[1280, 256]; hidden_norm=[256], [256];
            # vf (255 reward buckets)=[256, 255], [255].
            "XS_cartpole": 393727,
            "S_cartpole": 1181439,
            "M_cartpole": 2297215,
            "L_cartpole": 4331007,
            "XL_cartpole": 9708799,
            "XS_atari": 393727,
            "S_atari": 1181439,
            "M_atari": 2297215,
            "L_atari": 4331007,
            "XL_atari": 9708799,
        }
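
        # Same sanity check for the critic (same bias-free-dense assumption):
        # 1280*256 + 256+256 + 256*255+255 = 393727. The counts are identical
        # for CartPole and Atari because the critic operates purely on the
        # (equally sized) latent state, not on the observations.
        assert expected_num_params_critic["XS_cartpole"] == (
            1280 * 256 + 256 + 256 + 256 * 255 + 255
        )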

        config = dreamerv3.DreamerV3Config().training(
            batch_length_T=16,
            horizon_H=5,
            symlog_obs=True,
        )

        for _ in framework_iterator(config, frameworks="tf2"):
            # Check that the parameter counts of all model sizes described in
            # the paper ([1]) match RLlib's implementation.
            for model_size in ["XS", "S", "M", "L", "XL"]:
                config.model_size = model_size
                config.training(model={"model_size": model_size})

                # Atari and CartPole spaces.
                for obs_space, num_actions, env_name in [
                    (gym.spaces.Box(-1.0, 0.0, (4,), np.float32), 2, "cartpole"),
                    (gym.spaces.Box(-1.0, 0.0, (64, 64, 3), np.float32), 6, "atari"),
                ]:
                    print(f"Testing model_size={model_size} on env-type: {env_name} ..")
                    config.environment(
                        observation_space=obs_space,
                        action_space=gym.spaces.Discrete(num_actions),
                    )

                    # Create our RLModule to compute actions with.
                    policy_dict, _ = config.get_multi_agent_setup()
                    module_spec = config.get_marl_module_spec(policy_dict=policy_dict)
                    rl_module = module_spec.build()[DEFAULT_POLICY_ID]

                    # Count the generated RLModule's parameters and compare to the
                    # paper's reported numbers ([1] and [3]).
                    num_params_world_model = sum(
                        np.prod(v.shape.as_list())
                        for v in rl_module.world_model.trainable_variables
                    )
                    self.assertEqual(
                        num_params_world_model,
                        expected_num_params_world_model[f"{model_size}_{env_name}"],
                    )
                    num_params_actor = sum(
                        np.prod(v.shape.as_list())
                        for v in rl_module.actor.trainable_variables
                    )
                    self.assertEqual(
                        num_params_actor,
                        expected_num_params_actor[f"{model_size}_{env_name}"],
                    )
                    num_params_critic = sum(
                        np.prod(v.shape.as_list())
                        for v in rl_module.critic.trainable_variables
                    )
                    self.assertEqual(
                        num_params_critic,
                        expected_num_params_critic[f"{model_size}_{env_name}"],
                    )
                    print("\tok")


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))