123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250 |
- """
- [1] Mastering Diverse Domains through World Models - 2023
- D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
- https://arxiv.org/pdf/2301.04104v1.pdf
- [2] Mastering Atari with Discrete World Models - 2021
- D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
- https://arxiv.org/pdf/2010.02193.pdf
- [3]
- D. Hafner's (author) original code repo (for JAX):
- https://github.com/danijar/dreamerv3
- """
- import unittest
- import gymnasium as gym
- import numpy as np
- import ray
- from ray.rllib.algorithms.dreamerv3 import dreamerv3
- from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
- from ray.rllib.utils.numpy import one_hot
- from ray.rllib.utils.test_utils import framework_iterator
- from ray import tune
class TestDreamerV3(unittest.TestCase):
    """Tests for RLlib's DreamerV3: a build/train/dream smoke test and checks of
    the per-component parameter counts for the different model sizes."""

    @classmethod
    def setUpClass(cls):
        """Initialize Ray once before any test of this class runs."""
        ray.init()

    @classmethod
    def tearDownClass(cls):
        """Shut Ray down after all tests of this class have run."""
        ray.shutdown()

    @staticmethod
    def _count_trainable_params(component):
        """Return the summed element count of `component.trainable_variables`.

        Args:
            component: Any object exposing `trainable_variables`, each of which
                has a `shape` with an `as_list()` method.

        Returns:
            The sum over all variables of the product of their shape dims.
        """
        return sum(
            np.prod(v.shape.as_list()) for v in component.trainable_variables
        )

    def test_dreamerv3_compilation(self):
        """Test whether DreamerV3 can be built and trained with tf2.

        For each tested env, builds the algorithm, runs a few training
        iterations, then dreams one trajectory from a replay-buffer sample and
        checks the shapes of the dreamed actions, continues and observations.
        """
        # Build a DreamerV3Config object.
        config = (
            dreamerv3.DreamerV3Config()
            .training(
                # Keep things simple. Especially the long dream rollouts seem
                # to take an enormous amount of time (initially).
                batch_size_B=4,
                horizon_H=5,
                batch_length_T=16,
                model_size="nano",  # Use a tiny model for testing
                symlog_obs=True,
                use_float16=False,
            )
            .resources(
                num_learner_workers=2,  # Try with 2 Learners.
                num_cpus_per_learner_worker=1,
                num_gpus_per_learner_worker=0,
            )
        )

        # TODO (sven): Add a `get_model_config` utility to AlgorithmConfig
        #  that - for now - merges the user provided model_dict (which only
        #  contains settings that only affect the model, e.g. model_size)
        #  with the AlgorithmConfig-wide settings that are relevant for the model
        #  (e.g. `batch_size_B`).
        # config.get_model_config()

        num_iterations = 2
        # Settings of the dream rollout whose output shapes are checked below.
        timesteps_burn_in = 5
        timesteps_H = 45
        # Dreamed outputs span t0..H (inclusive), i.e. H + 1 timesteps, for the
        # single (B=1) trajectory passed in.
        expected_T_B = (timesteps_H + 1, 1)

        for _ in framework_iterator(config, frameworks="tf2"):
            for env in [
                "FrozenLake-v1",
                "CartPole-v1",
                "ALE/MsPacman-v5",
                "Pendulum-v1",
            ]:
                print(f"Env={env}")

                # Name of the env actually handed to the config (kept separate
                # from the loop variable so the latter is never overwritten).
                env_id = env
                # Add one-hot observations for FrozenLake env.
                if env == "FrozenLake-v1":

                    def env_creator(ctx):
                        # NOTE(review): imports are local to the creator,
                        # presumably so it stays self-contained when shipped
                        # to remote workers - confirm.
                        import gymnasium as gym
                        from ray.rllib.algorithms.dreamerv3.utils.env_runner import (
                            OneHot,
                        )

                        return OneHot(gym.make("FrozenLake-v1"))

                    tune.register_env("frozen-lake-one-hot", env_creator)
                    env_id = "frozen-lake-one-hot"

                config.environment(env_id)
                algo = config.build()
                # Always stop the algorithm, even if training or one of the
                # assertions below fails; otherwise its workers would stay
                # alive during the remaining env iterations.
                try:
                    local_worker = algo.workers.local_worker()
                    obs_space = local_worker.env.single_observation_space
                    act_space = local_worker.env.single_action_space
                    rl_module = local_worker.module

                    for _ in range(num_iterations):
                        results = algo.train()
                        print(results)

                    # Test dream trajectory w/ recreated observations.
                    sample = algo.replay_buffer.sample()
                    actions = sample["actions"]
                    if isinstance(act_space, gym.spaces.Discrete):
                        # Discrete actions are fed to the model one-hot encoded.
                        actions = one_hot(actions, depth=act_space.n)
                        action_dims = (act_space.n,)
                    else:
                        action_dims = tuple(act_space.shape)

                    dreamer_model = rl_module.dreamer_model
                    dream = dreamer_model.dream_trajectory_with_burn_in(
                        start_states=dreamer_model.get_initial_state(),
                        timesteps_burn_in=timesteps_burn_in,
                        timesteps_H=timesteps_H,
                        observations=sample["obs"][:1],  # B=1
                        actions=actions[:1],  # B=1
                    )

                    # `assertEqual` performs the same `==` comparison as the
                    # former `assertTrue(a == b)`, but reports both operands
                    # on failure.
                    self.assertEqual(
                        dream["actions_dreamed_t0_to_H_BxT"].shape,
                        expected_T_B + action_dims,
                    )
                    self.assertEqual(
                        dream["continues_dreamed_t0_to_H_BxT"].shape,
                        expected_T_B,
                    )
                    self.assertEqual(
                        dream["observations_dreamed_t0_to_H_BxT"].shape,
                        list(expected_T_B) + list(obs_space.shape),
                    )
                finally:
                    algo.stop()

    def test_dreamerv3_dreamer_model_sizes(self):
        """Tests, whether the different model sizes match the ones reported in [1].

        Builds the RLModule for each model size ("XS" through "XL") on a
        CartPole-like and an Atari-like observation/action space and compares
        the trainable-parameter counts of the world model, the actor and the
        critic against the expected numbers listed below.
        """

        # For Atari, these are the exact numbers from the repo ([3]).
        # However, for CartPole + size "S" and "M", the author's original code will not
        # match for the world model count. This is due to the fact that the author uses
        # encoder/decoder nets with 5x1024 nodes (which corresponds to XL) regardless of
        # the `model_size` settings (iff >="S").
        expected_num_params_world_model = {
            "XS_cartpole": 2435076,
            "S_cartpole": 7493380,
            "M_cartpole": 16206084,
            "L_cartpole": 37802244,
            "XL_cartpole": 108353796,
            "XS_atari": 7538979,
            "S_atari": 15687811,
            "M_atari": 32461635,
            "L_atari": 68278275,
            "XL_atari": 181558659,
        }

        # All values confirmed against [3] (100% match).
        expected_num_params_actor = {
            # hidden=[1280, 256]
            # hidden_norm=[256], [256]
            # pi (2 actions)=[256, 2], [2]
            "XS_cartpole": 328706,
            "S_cartpole": 1051650,
            "M_cartpole": 2135042,
            "L_cartpole": 4136450,
            "XL_cartpole": 9449474,
            "XS_atari": 329734,
            "S_atari": 1053702,
            "M_atari": 2137606,
            "L_atari": 4139526,
            "XL_atari": 9453574,
        }

        # All values confirmed against [3] (100% match).
        expected_num_params_critic = {
            # hidden=[1280, 256]
            # hidden_norm=[256], [256]
            # vf (buckets)=[256, 255], [255]
            "XS_cartpole": 393727,
            "S_cartpole": 1181439,
            "M_cartpole": 2297215,
            "L_cartpole": 4331007,
            "XL_cartpole": 9708799,
            "XS_atari": 393727,
            "S_atari": 1181439,
            "M_atari": 2297215,
            "L_atari": 4331007,
            "XL_atari": 9708799,
        }

        config = dreamerv3.DreamerV3Config().training(
            batch_length_T=16,
            horizon_H=5,
            symlog_obs=True,
        )

        for _ in framework_iterator(config, frameworks="tf2"):
            # Check all model_sizes described in the paper ([1]) on matching the number
            # of parameters to RLlib's implementation.
            for model_size in ["XS", "S", "M", "L", "XL"]:
                config.model_size = model_size
                config.training(model={"model_size": model_size})

                # Atari and CartPole spaces.
                for obs_space, num_actions, env_name in [
                    (gym.spaces.Box(-1.0, 0.0, (4,), np.float32), 2, "cartpole"),
                    (gym.spaces.Box(-1.0, 0.0, (64, 64, 3), np.float32), 6, "atari"),
                ]:
                    print(f"Testing model_size={model_size} on env-type: {env_name} ..")
                    config.environment(
                        observation_space=obs_space,
                        action_space=gym.spaces.Discrete(num_actions),
                    )

                    # Create our RLModule to compute actions with.
                    policy_dict, _ = config.get_multi_agent_setup()
                    module_spec = config.get_marl_module_spec(policy_dict=policy_dict)
                    rl_module = module_spec.build()[DEFAULT_POLICY_ID]

                    # Count the generated RLModule's parameters and compare to the
                    # paper's reported numbers ([1] and [3]).
                    key = f"{model_size}_{env_name}"
                    self.assertEqual(
                        self._count_trainable_params(rl_module.world_model),
                        expected_num_params_world_model[key],
                    )
                    self.assertEqual(
                        self._count_trainable_params(rl_module.actor),
                        expected_num_params_actor[key],
                    )
                    self.assertEqual(
                        self._count_trainable_params(rl_module.critic),
                        expected_num_params_critic[key],
                    )
                    print("\tok")
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode; the
    # process then exits with whatever pytest.main() returned.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|